26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
/// Translates the optional MLIR schedule-clause kind into the matching
/// `llvm::omp::ScheduleKind` enumerator. When no kind is supplied the result
/// is `OMP_SCHEDULE_Default`.
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
// Every case listed above returns, so control only gets here for a clause
// kind that has no mapping in the switch.
71 llvm_unreachable(
"unhandled schedule clause argument");
/// Frame object that bundles an insertion point for allocations with a list of
/// basic blocks related to deallocation.
/// NOTE(review): the base class and access specifiers are not visible in this
/// excerpt; elsewhere in the file a stack walk reads both members from a frame
/// of this type.
76class OpenMPAllocStackFrame
/// Stores `allocaIP` and copies the `deallocBlocks` range into an owned
/// vector, so the frame does not depend on the caller's storage afterwards.
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
// Insertion point handed to the constructor as `allocaIP`.
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
// Owned copy of the blocks handed to the constructor.
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
/// Frame object carrying a pointer to an `llvm::CanonicalLoopInfo`.
/// NOTE(review): the base class is not visible in this excerpt; elsewhere in
/// the file a stack walk reads `loopInfo` from a frame of this type.
92class OpenMPLoopInfoStackFrame
// Starts out null until some code assigns a loop to it.
96 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
/// `llvm::ErrorInfo` subclass with no payload of its own.
/// NOTE(review): judging by the name and by the handler elsewhere in this file
/// that only sets `result = failure()` for it, this appears to signal an error
/// whose diagnostic was already emitted — confirm against the full source.
115class PreviouslyReportedError
116 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Override of ErrorInfo's logging hook; the stream parameter is unnamed.
118 void log(raw_ostream &)
const override {
// Conversion to std::error_code is not supported, per the message below.
122 std::error_code convertToErrorCode()
const override {
124 "PreviouslyReportedError doesn't support ECError conversion");
// Out-of-class definition of the static `ID` member for this error class.
131char PreviouslyReportedError::ID = 0;
/// Helper that accumulates per-variable state for variables handled by its
/// `*Linear*` methods and emits the LLVM IR that initializes, updates and
/// finalizes them around a loop. All vectors below are indexed in parallel:
/// entry `i` of each describes the same variable.
142class LinearClauseProcessor {
// Allocas named ".linear_var"; initLinearVar stores each variable's value
// loaded in the loop pre-header into these.
145 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas named ".linear_result"; updateLinearVar stores the per-iteration
// value into these.
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original variable pointers passed to createLinearVar.
147 SmallVector<llvm::Value *> linearOrigVal;
// Step values looked up by initLinearStep.
148 SmallVector<llvm::Value *> linearSteps;
// LLVM types produced by registerType.
149 SmallVector<llvm::Type *> linearVarTypes;
// The three blocks below are assigned by splitLinearFiniBB.
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
/// Records the LLVM type of one linear variable: `ty` is cast to a
/// `mlir::TypeAttr`, its contained type is converted through
/// `moduleTranslation`, and the result is appended to `linearVarTypes`.
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.
convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
/// Creates the two stack slots used for the linear variable at position `idx`
/// and remembers the original variable.
///
/// \param builder     IR builder; both allocas are created at its current
///                    insertion point.
/// \param linearVar   Original variable value, appended to `linearOrigVal`.
/// \param idx         Index into `linearVarTypes` giving the slot type; the
///                    type must already have been registered.
/// `moduleTranslation` is not used by the lines visible here.
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar,
int idx) {
// Slot that later receives the variable's value from the loop pre-header.
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
// Slot that later receives the value computed on each loop iteration.
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
/// Looks up the LLVM value that `moduleTranslation` maps `linearStep` to and
/// appends it to `linearSteps`.
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
/// Emits, just before the terminator of `loopPreHeader`, a load of every
/// original linear variable followed by a store of that value into the
/// matching `.linear_var` slot. Leaves the builder positioned in the
/// pre-header. `moduleTranslation` is not used by the lines visible here.
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
// Copy the variable's value at loop entry into its precondition slot.
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
/// Emits, just before the terminator of `loopBody`, the per-iteration value of
/// every linear variable, computed as `start + loopInductionVar * step`, and
/// stores it into the matching `.linear_result` slot.
///
/// \param builder           IR builder; repositioned into `loopBody`.
/// \param loopBody          Block that receives the generated instructions.
/// \param loopInductionVar  Induction variable value; must have integer type.
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
// A non-integer induction variable is treated as an unreachable state.
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
// Integer variable: sign-extend or truncate both operands to the variable's
// type, then do the multiply and add in that type.
205 if (linearVarType->isIntegerTy()) {
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 }
// Floating-point variable: the product iv * step is formed in the induction
// variable's integer type, converted with a signed int-to-FP conversion, and
// added to the start value with a floating-point add.
else if (linearVarType->isFloatingPointTy()) {
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// NOTE(review): the statement this message belongs to is not visible in this
// excerpt; presumably it rejects any other variable type — confirm.
227 "Linear variable must be of integer or floating-point type");
/// Carves three new blocks out of `loopExit` and records them in
/// `linearFinalizationBB`, `linearExitBB` and `linearLastIterExitBB`.
/// `builder` is not used by the lines visible here.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
// Split at the terminator: the new block takes over loopExit's terminator.
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(),
"omp_loop.linear_finalization");
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
// Splitting the finalization block a second time places the last-iteration
// block between the finalization block and the exit block split off above.
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
/// Emits the code that writes the final value of every linear variable back to
/// its original location when the last-iteration flag is set.
///
/// \param builder   IR builder; repositioned several times by this method.
/// \param lastIter  Pointer from which a 32-bit integer flag is loaded; a
///                  non-zero value selects the write-back path.
/// \returns an insertion point or an error; the returning statement is only
///          partly visible in this excerpt.
/// `moduleTranslation` is not used by the lines visible here.
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
// Compute `lastIter != 0` in the finalization block.
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
// In the last-iteration block, copy each `.linear_result` slot back to the
// original variable.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Replace the finalization block's existing terminator with a conditional
// branch: the new branch is inserted before the old terminator, and the old
// terminator (still the block's last instruction) is then erased.
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
// NOTE(review): the callee receiving the arguments on the last line is not
// visible here; the OMPD_barrier directive kind suggests a barrier is emitted
// at the exit block — confirm.
272 builder.SetInsertPoint(linearExitBB->getTerminator());
274 builder.saveIP(), llvm::omp::OMPD_barrier);
/// Emits, at the builder's current insertion point, a load of every
/// `.linear_result` slot followed by a store of that value into the
/// corresponding original variable. Unlike finalizeLinearVar, no condition or
/// branch is generated here.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
/// Redirects uses of the original linear variable at `varIndex` to its
/// `.linear_result` slot, restricted to instructions selected by the name of
/// their parent block.
/// NOTE(review): the rest of the parameter list (including `varIndex`) and the
/// right-hand side of the `find(...) !=` comparison are not visible in this
/// excerpt; presumably the test is "block name contains BBName" — confirm.
/// `builder` is not used by the lines visible here.
289 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
// Snapshot the users first: replaceUsesOfWith below edits the use list of the
// original value, so it must not be iterated directly while rewriting.
291 llvm::SmallVector<llvm::User *> users;
292 for (llvm::User *user : linearOrigVal[varIndex]->users())
293 users.push_back(user);
294 for (
auto *user : users) {
// Only users that are instructions are considered.
295 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
296 if (userInst->getParent()->getName().str().find(BBName) !=
298 user->replaceUsesOfWith(linearOrigVal[varIndex],
299 linearLoopBodyTemps[varIndex]);
310 SymbolRefAttr symbolName) {
311 omp::PrivateClauseOp privatizer =
314 assert(privatizer &&
"privatizer not found in the symbol table");
325 auto todo = [&op](StringRef clauseName) {
326 return op.
emitError() <<
"not yet implemented: Unhandled clause "
327 << clauseName <<
" in " << op.
getName()
331 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
332 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
333 result = todo(
"allocate");
335 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
337 result = todo(
"ompx_bare");
339 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
340 if (!op.getDependVars().empty() || op.getDependKinds())
343 auto checkHint = [](
auto op, LogicalResult &) {
347 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
348 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
349 op.getInReductionSyms())
350 result = todo(
"in_reduction");
352 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
356 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
357 if (op.getOrder() || op.getOrderMod())
360 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
361 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
362 result = todo(
"privatization");
364 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
365 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
366 if (!op.getReductionVars().empty() || op.getReductionByref() ||
367 op.getReductionSyms())
368 result = todo(
"reduction");
369 if (op.getReductionMod() &&
370 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
371 result = todo(
"reduction with modifier");
373 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
374 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
375 op.getTaskReductionSyms())
376 result = todo(
"task_reduction");
378 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
379 if (op.hasNumTeamsMultiDim())
380 result = todo(
"num_teams with multi-dimensional values");
382 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
383 if (op.hasNumThreadsMultiDim())
384 result = todo(
"num_threads with multi-dimensional values");
387 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
388 if (op.hasThreadLimitMultiDim())
389 result = todo(
"thread_limit with multi-dimensional values");
392 auto checkDynGroupprivate = [&todo](
auto op, LogicalResult &
result) {
393 if (op.getDynGroupprivateSize())
394 result = todo(
"dyn_groupprivate");
399 .Case([&](omp::DistributeOp op) {
400 checkAllocate(op,
result);
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op,
result);
406 checkReduction(op,
result);
408 .Case([&](omp::ScopeOp op) {
409 checkAllocate(op,
result);
410 checkReduction(op,
result);
412 .Case([&](omp::SingleOp op) {
413 checkAllocate(op,
result);
416 .Case([&](omp::TeamsOp op) {
417 checkAllocate(op,
result);
419 checkNumTeams(op,
result);
420 checkThreadLimit(op,
result);
421 checkDynGroupprivate(op,
result);
423 .Case([&](omp::TaskOp op) {
424 checkAllocate(op,
result);
425 checkInReduction(op,
result);
427 .Case([&](omp::TaskgroupOp op) {
428 checkAllocate(op,
result);
429 checkTaskReduction(op,
result);
431 .Case([&](omp::TaskwaitOp op) {
435 .Case([&](omp::TaskloopContextOp op) {
436 checkAllocate(op,
result);
437 checkInReduction(op,
result);
438 checkReduction(op,
result);
440 .Case([&](omp::WsloopOp op) {
441 checkAllocate(op,
result);
443 checkReduction(op,
result);
445 .Case([&](omp::ParallelOp op) {
446 checkAllocate(op,
result);
447 checkReduction(op,
result);
448 checkNumThreads(op,
result);
450 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
451 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
452 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
453 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
454 [&](
auto op) { checkDepend(op,
result); })
455 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
456 .Case([&](omp::TargetOp op) {
457 checkAllocate(op,
result);
459 checkInReduction(op,
result);
460 checkThreadLimit(op,
result);
472 llvm::handleAllErrors(
474 [&](
const PreviouslyReportedError &) {
result = failure(); },
475 [&](
const llvm::ErrorInfoBase &err) {
498 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
501 [&](OpenMPAllocStackFrame &frame) {
502 allocInsertPoint = frame.allocInsertPoint;
503 deallocInsertPoints = frame.deallocBlocks;
511 allocInsertPoint.getBlock()->getParent() ==
512 builder.GetInsertBlock()->getParent()) {
514 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
515 deallocInsertPoints.end());
516 return allocInsertPoint;
526 if (builder.GetInsertBlock() ==
527 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
528 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
529 "Assuming end of basic block");
530 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
531 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
532 builder.GetInsertBlock()->getNextNode());
533 builder.CreateBr(entryBB);
534 builder.SetInsertPoint(entryBB);
540 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
544 llvm::Instruction *terminator = block.getTerminatorOrNull();
545 if (isa_and_present<llvm::ReturnInst>(terminator))
546 deallocBlocks->emplace_back(&block);
550 llvm::BasicBlock &funcEntryBlock =
551 builder.GetInsertBlock()->getParent()->getEntryBlock();
552 return llvm::OpenMPIRBuilder::InsertPointTy(
553 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
559static llvm::CanonicalLoopInfo *
561 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
562 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
563 [&](OpenMPLoopInfoStackFrame &frame) {
564 loopInfo = frame.loopInfo;
576 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
579 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
581 llvm::BasicBlock *continuationBlock =
582 splitBB(builder,
true,
"omp.region.cont");
583 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
585 llvm::LLVMContext &llvmContext = builder.getContext();
586 for (
Block &bb : region) {
587 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
588 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
589 builder.GetInsertBlock()->getNextNode());
590 moduleTranslation.
mapBlock(&bb, llvmBB);
593 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
600 unsigned numYields = 0;
602 if (!isLoopWrapper) {
603 bool operandsProcessed =
false;
605 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
606 if (!operandsProcessed) {
607 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
608 continuationBlockPHITypes.push_back(
609 moduleTranslation.
convertType(yield->getOperand(i).getType()));
611 operandsProcessed =
true;
613 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
614 "mismatching number of values yielded from the region");
615 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
616 llvm::Type *operandType =
617 moduleTranslation.
convertType(yield->getOperand(i).getType());
619 assert(continuationBlockPHITypes[i] == operandType &&
620 "values of mismatching types yielded from the region");
630 if (!continuationBlockPHITypes.empty())
632 continuationBlockPHIs &&
633 "expected continuation block PHIs if converted regions yield values");
634 if (continuationBlockPHIs) {
635 llvm::IRBuilderBase::InsertPointGuard guard(builder);
636 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
637 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
638 for (llvm::Type *ty : continuationBlockPHITypes)
639 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
645 for (
Block *bb : blocks) {
646 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
649 if (bb->isEntryBlock()) {
650 assert(sourceTerminator->getNumSuccessors() == 1 &&
651 "provided entry block has multiple successors");
652 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
653 "ContinuationBlock is not the successor of the entry block");
654 sourceTerminator->setSuccessor(0, llvmBB);
657 llvm::IRBuilderBase::InsertPointGuard guard(builder);
659 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
660 return llvm::make_error<PreviouslyReportedError>();
665 builder.CreateBr(continuationBlock);
676 Operation *terminator = bb->getTerminator();
677 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
678 builder.CreateBr(continuationBlock);
680 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
681 (*continuationBlockPHIs)[i]->addIncoming(
695 return continuationBlock;
701 case omp::ClauseProcBindKind::Close:
702 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
703 case omp::ClauseProcBindKind::Master:
704 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
705 case omp::ClauseProcBindKind::Primary:
706 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
707 case omp::ClauseProcBindKind::Spread:
708 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
710 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
717 auto maskedOp = cast<omp::MaskedOp>(opInst);
718 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
723 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
726 auto ®ion = maskedOp.getRegion();
727 builder.restoreIP(codeGenIP);
735 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
737 llvm::Value *filterVal =
nullptr;
738 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
739 filterVal = moduleTranslation.
lookupValue(filterVar);
741 llvm::LLVMContext &llvmContext = builder.getContext();
743 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
745 assert(filterVal !=
nullptr);
746 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
747 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
754 builder.restoreIP(*afterIP);
762 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
763 auto masterOp = cast<omp::MasterOp>(opInst);
768 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
771 auto ®ion = masterOp.getRegion();
772 builder.restoreIP(codeGenIP);
780 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
782 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
783 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
790 builder.restoreIP(*afterIP);
798 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
799 auto criticalOp = cast<omp::CriticalOp>(opInst);
804 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
807 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
808 builder.restoreIP(codeGenIP);
816 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
818 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
819 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
820 llvm::Constant *hint =
nullptr;
823 if (criticalOp.getNameAttr()) {
826 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
827 auto criticalDeclareOp =
831 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
832 static_cast<int>(criticalDeclareOp.getHint()));
834 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
836 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
841 builder.restoreIP(*afterIP);
848 template <
typename OP>
851 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
854 collectPrivatizationDecls<OP>(op);
869 void collectPrivatizationDecls(OP op) {
870 std::optional<ArrayAttr> attr = op.getPrivateSyms();
875 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
886 std::optional<ArrayAttr> attr = op.getReductionSyms();
890 reductions.reserve(reductions.size() + op.getNumReductionVars());
891 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
892 reductions.push_back(
904 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
913 llvm::Instruction *potentialTerminator =
914 builder.GetInsertBlock()->empty() ?
nullptr
915 : &builder.GetInsertBlock()->back();
917 if (potentialTerminator && potentialTerminator->isTerminator())
918 potentialTerminator->removeFromParent();
919 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
922 region.
front(),
true, builder)))
926 if (continuationBlockArgs)
928 *continuationBlockArgs,
935 if (potentialTerminator && potentialTerminator->isTerminator()) {
936 llvm::BasicBlock *block = builder.GetInsertBlock();
937 if (block->empty()) {
943 potentialTerminator->insertInto(block, block->begin());
945 potentialTerminator->insertAfter(&block->back());
959 if (continuationBlockArgs)
960 llvm::append_range(*continuationBlockArgs, phis);
961 builder.SetInsertPoint(*continuationBlock,
962 (*continuationBlock)->getFirstInsertionPt());
/// `std::function` aliases for reduction code-generation callbacks. Each
/// callable receives an insertion point plus operand values and returns an
/// `InsertPointOrErrorTy`. Being `std::function`, an instance owns its
/// callable target.

/// Callback type taking an insertion point and two values.
/// NOTE(review): the end of this parameter list is not visible in this
/// excerpt.
969using OwningReductionGen =
970 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
971 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
/// Callback type taking an insertion point, a type and a value.
/// NOTE(review): the end of this parameter list is not visible in this
/// excerpt.
973using OwningAtomicReductionGen =
974 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
975 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
/// Callback type taking an insertion point, an input value, and a reference
/// to a value pointer through which a result can be written.
977using OwningDataPtrPtrReductionGen =
978 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
979 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
985static OwningReductionGen
991 OwningReductionGen gen =
992 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
993 llvm::Value *
lhs, llvm::Value *
rhs,
994 llvm::Value *&
result)
mutable
995 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
996 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
997 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
998 builder.restoreIP(insertPoint);
1001 "omp.reduction.nonatomic.body", builder,
1002 moduleTranslation, &phis)))
1003 return llvm::createStringError(
1004 "failed to inline `combiner` region of `omp.declare_reduction`");
1005 result = llvm::getSingleElement(phis);
1006 return builder.saveIP();
1015static OwningAtomicReductionGen
1017 llvm::IRBuilderBase &builder,
1019 if (decl.getAtomicReductionRegion().empty())
1020 return OwningAtomicReductionGen();
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1027 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1030 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1031 builder.restoreIP(insertPoint);
1034 "omp.reduction.atomic.body", builder,
1035 moduleTranslation, &phis)))
1036 return llvm::createStringError(
1037 "failed to inline `atomic` region of `omp.declare_reduction`");
1038 assert(phis.empty());
1039 return builder.saveIP();
1048static OwningDataPtrPtrReductionGen
1051 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1052 return OwningDataPtrPtrReductionGen();
1054 OwningDataPtrPtrReductionGen refDataPtrGen =
1055 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1056 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1057 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1058 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1059 builder.restoreIP(insertPoint);
1062 "omp.data_ptr_ptr.body", builder,
1063 moduleTranslation, &phis)))
1064 return llvm::createStringError(
1065 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1066 result = llvm::getSingleElement(phis);
1067 return builder.saveIP();
1070 return refDataPtrGen;
1077 auto orderedOp = cast<omp::OrderedOp>(opInst);
1082 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1083 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1084 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1086 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1088 size_t indexVecValues = 0;
1089 while (indexVecValues < vecValues.size()) {
1091 storeValues.reserve(numLoops);
1092 for (
unsigned i = 0; i < numLoops; i++) {
1093 storeValues.push_back(vecValues[indexVecValues]);
1096 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1098 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1099 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1100 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1110 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1111 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1116 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1119 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1120 builder.restoreIP(codeGenIP);
1128 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1130 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1131 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1133 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1138 builder.restoreIP(*afterIP);
/// A (value, address) pair describing a store that is emitted later rather
/// than at the point where the pair is recorded; elsewhere in this file a loop
/// over such pairs issues `CreateStore(value, address)` for each one.
1144struct DeferredStore {
// Keeps both pointers as given; no IR is generated by the constructor.
1145 DeferredStore(llvm::Value *value, llvm::Value *address)
1146 : value(value), address(address) {}
// Destination of the pending store.
1149 llvm::Value *address;
1156template <
typename T>
1159 llvm::IRBuilderBase &builder,
1161 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1167 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1168 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1174 deferredStores.reserve(op.getNumReductionVars());
1176 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1177 Region &allocRegion = reductionDecls[i].getAllocRegion();
1179 if (allocRegion.
empty())
1184 builder, moduleTranslation, &phis)))
1185 return op.emitError(
1186 "failed to inline `alloc` region of `omp.declare_reduction`");
1188 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1189 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1193 llvm::Type *ptrTy = builder.getPtrTy();
1197 if (useDeviceSharedMem) {
1198 var = ompBuilder->createOMPAllocShared(builder, varTy);
1200 var = builder.CreateAlloca(varTy);
1201 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1204 llvm::Value *castPhi =
1205 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1207 deferredStores.emplace_back(castPhi, var);
1209 privateReductionVariables[i] = var;
1210 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1211 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1213 assert(allocRegion.
empty() &&
1214 "allocaction is implicit for by-val reduction");
1216 llvm::Type *ptrTy = builder.getPtrTy();
1220 if (useDeviceSharedMem) {
1221 var = ompBuilder->createOMPAllocShared(builder, varTy);
1223 var = builder.CreateAlloca(varTy);
1224 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1227 moduleTranslation.
mapValue(reductionArgs[i], var);
1228 privateReductionVariables[i] = var;
1229 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1237template <
typename T>
1240 llvm::IRBuilderBase &builder,
1245 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1246 Region &initializerRegion = reduction.getInitializerRegion();
1249 mlir::Value mlirSource = loop.getReductionVars()[i];
1250 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1251 llvm::Value *origVal = llvmSource;
1253 if (!isa<LLVM::LLVMPointerType>(
1254 reduction.getInitializerMoldArg().getType()) &&
1255 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1258 reduction.getInitializerMoldArg().getType()),
1259 llvmSource,
"omp_orig");
1261 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1264 llvm::Value *allocation =
1265 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1266 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1272 llvm::BasicBlock *block =
nullptr) {
1273 if (block ==
nullptr)
1274 block = builder.GetInsertBlock();
1276 if (!block->hasTerminator())
1277 builder.SetInsertPoint(block);
1279 builder.SetInsertPoint(block->getTerminator());
1287template <
typename OP>
1290 llvm::IRBuilderBase &builder,
1292 llvm::BasicBlock *latestAllocaBlock,
1298 if (op.getNumReductionVars() == 0)
1304 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1305 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1306 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1307 builder.restoreIP(allocaIP);
1310 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1312 if (!reductionDecls[i].getAllocRegion().empty())
1320 if (useDeviceSharedMem)
1321 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1323 byRefVars[i] = builder.CreateAlloca(varTy);
1331 for (
auto [data, addr] : deferredStores)
1332 builder.CreateStore(data, addr);
1337 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1342 reductionVariableMap, i);
1350 "omp.reduction.neutral", builder,
1351 moduleTranslation, &phis)))
1354 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1355 "reduction neutral element declaration region");
1360 if (!reductionDecls[i].getAllocRegion().empty())
1369 builder.CreateStore(phis[0], byRefVars[i]);
1371 privateReductionVariables[i] = byRefVars[i];
1372 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1373 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1376 builder.CreateStore(phis[0], privateReductionVariables[i]);
1383 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
/// Builds the per-reduction generator callbacks and the `reductionInfos`
/// entries for every reduction variable of `loop`.
/// NOTE(review): most of the parameter list and several interior lines are
/// not visible in this excerpt; the comments below describe only what the
/// visible lines show.
1390template <
typename T>
1391static void collectReductionInfo(
1392 T loop, llvm::IRBuilderBase &builder,
1401 unsigned numReductions = loop.getNumReductionVars();
// First pass: create the owning generator objects, one per reduction.
1403 for (
unsigned i = 0; i < numReductions; ++i) {
1406 owningAtomicReductionGens.push_back(
1409 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
// Second pass: assemble one info entry per reduction.
1413 reductionInfos.reserve(numReductions);
1414 for (
unsigned i = 0; i < numReductions; ++i) {
// The atomic callback stays null unless an atomic generator was produced for
// this reduction.
1415 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1416 if (owningAtomicReductionGens[i])
1417 atomicGen = owningAtomicReductionGens[i];
1418 llvm::Value *variable =
1419 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
// When the inspected operation is an LLVM alloca, take its element type.
1422 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1423 allocatedType = alloca.getElemType();
1430 reductionInfos.push_back(
1432 privateReductionVariables[i],
1433 llvm::OpenMPIRBuilder::EvalKind::Scalar,
// A missing allocated type is passed on as a null LLVM type.
1437 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1438 reductionDecls[i].getByrefElementType()
1440 *reductionDecls[i].getByrefElementType())
1450 llvm::IRBuilderBase &builder, StringRef regionName,
1451 bool shouldLoadCleanupRegionArg =
true) {
1452 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1453 if (cleanupRegion->empty())
1459 llvm::Instruction *potentialTerminator =
1460 builder.GetInsertBlock()->empty() ?
nullptr
1461 : &builder.GetInsertBlock()->back();
1462 if (potentialTerminator && potentialTerminator->isTerminator())
1463 builder.SetInsertPoint(potentialTerminator);
1464 llvm::Value *privateVarValue =
1465 shouldLoadCleanupRegionArg
1466 ? builder.CreateLoad(
1468 privateVariables[i])
1469 : privateVariables[i];
1474 moduleTranslation)))
1487 OP op, llvm::IRBuilderBase &builder,
1489 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1492 bool isNowait =
false,
bool isTeamsReduction =
false) {
1494 if (op.getNumReductionVars() == 0)
1506 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1508 owningReductionGenRefDataPtrGens,
1509 privateReductionVariables, reductionInfos, isByRef);
1514 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1515 builder.SetInsertPoint(tempTerminator);
1516 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1517 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1518 isByRef, isNowait, isTeamsReduction);
1523 if (!contInsertPoint->getBlock())
1524 return op->emitOpError() <<
"failed to convert reductions";
1526 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1527 if (!isTeamsReduction) {
1528 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1529 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1533 afterIP = *barrierIP;
1536 tempTerminator->eraseFromParent();
1537 builder.restoreIP(afterIP);
1541 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1542 [](omp::DeclareReductionOp reductionDecl) {
1543 return &reductionDecl.getCleanupRegion();
1546 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1547 "omp.reduction.cleanup");
1550 if (useDeviceSharedMem) {
1551 for (
auto [var, reductionDecl] :
1552 llvm::zip_equal(privateReductionVariables, reductionDecls))
1553 ompBuilder->createOMPFreeShared(
1554 builder, var, moduleTranslation.
convertType(reductionDecl.getType()));
1567template <
typename OP>
1571 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1576 if (op.getNumReductionVars() == 0)
1582 allocaIP, reductionDecls,
1583 privateReductionVariables, reductionVariableMap,
1584 deferredStores, isByRef)))
1587 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1588 allocaIP.getBlock(), reductionDecls,
1589 privateReductionVariables, reductionVariableMap,
1590 isByRef, deferredStores);
1604 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1607 Value blockArg = (*mappedPrivateVars)[privateVar];
1610 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1611 "A block argument corresponding to a mapped var should have "
1614 if (privVarType == blockArgType)
1621 if (!isa<LLVM::LLVMPointerType>(privVarType))
1622 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1635 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1637 llvm::BasicBlock *privInitBlock,
1639 Region &initRegion = privDecl.getInitRegion();
1640 if (initRegion.
empty())
1641 return llvmPrivateVar;
1643 assert(nonPrivateVar);
1644 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1645 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1650 moduleTranslation, &phis)))
1651 return llvm::createStringError(
1652 "failed to inline `init` region of `omp.private`");
1654 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1671 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1674 builder, moduleTranslation, privDecl,
1677 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1686 return llvm::Error::success();
1688 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1691 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1694 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1696 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1697 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1700 return privVarOrErr.takeError();
1702 llvmPrivateVar = privVarOrErr.get();
1703 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1708 return llvm::Error::success();
1714template <
typename T>
1719 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1722 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1723 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1724 allocaTerminator->getIterator()),
1725 true, allocaTerminator->getStableDebugLoc(),
1726 "omp.region.after_alloca");
1728 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1730 allocaTerminator = allocaIP.getBlock()->getTerminator();
1731 builder.SetInsertPoint(allocaTerminator);
1733 assert(allocaTerminator->getNumSuccessors() == 1 &&
1734 "This is an unconditional branch created by splitBB");
1736 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1737 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1741 unsigned int allocaAS =
1742 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1745 .getProgramAddressSpace();
1747 for (
auto [privDecl, mlirPrivVar, blockArg] :
1750 llvm::Type *llvmAllocType =
1751 moduleTranslation.
convertType(privDecl.getType());
1752 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1753 llvm::Value *llvmPrivateVar =
nullptr;
1755 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1757 llvmPrivateVar = builder.CreateAlloca(
1758 llvmAllocType,
nullptr,
"omp.private.alloc");
1759 if (allocaAS != defaultAS)
1760 llvmPrivateVar = builder.CreateAddrSpaceCast(
1761 llvmPrivateVar, builder.getPtrTy(defaultAS));
1764 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1767 return afterAllocas;
1775 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1784 if (mlir::isa<omp::ParallelOp>(parent))
1798 bool needsFirstprivate =
1799 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1800 return privOp.getDataSharingType() ==
1801 omp::DataSharingClauseType::FirstPrivate;
1804 if (!needsFirstprivate)
1807 llvm::BasicBlock *copyBlock =
1808 splitBB(builder,
true,
"omp.private.copy");
1811 for (
auto [decl, moldVar, llvmVar] :
1812 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1813 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1817 Region ©Region = decl.getCopyRegion();
1819 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1822 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1826 moduleTranslation)))
1827 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1842 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1858 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1860 llvm::Value *moldVar = findAssociatedValue(
1861 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1866 llvmPrivateVars, privateDecls, insertBarrier,
1870template <
typename T>
1878 std::back_inserter(privateCleanupRegions),
1879 [](omp::PrivateClauseOp privatizer) {
1880 return &privatizer.getDeallocRegion();
1884 privateVarsInfo.
llvmVars, moduleTranslation,
1885 builder,
"omp.private.dealloc",
1887 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1888 "`omp.private` op in");
1892 for (
auto [privDecl, llvmPrivVar, blockArg] :
1896 ompBuilder->createOMPFreeShared(
1897 builder, llvmPrivVar,
1898 moduleTranslation.
convertType(privDecl.getType()));
1912 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1922 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1923 using StorableBodyGenCallbackTy =
1924 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1926 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1932 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1936 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1940 sectionsOp.getNumReductionVars());
1944 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1947 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1948 reductionDecls, privateReductionVariables, reductionVariableMap,
1955 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1959 Region ®ion = sectionOp.getRegion();
1960 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1961 InsertPointTy allocaIP, InsertPointTy codeGenIP,
1963 builder.restoreIP(codeGenIP);
1970 sectionsOp.getRegion().getNumArguments());
1971 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1972 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1973 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1975 moduleTranslation.
mapValue(sectionArg, llvmVal);
1982 sectionCBs.push_back(sectionCB);
1988 if (sectionCBs.empty())
1991 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1996 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1997 llvm::Value &vPtr, llvm::Value *&replacementValue)
1998 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1999 replacementValue = &vPtr;
2005 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2009 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2010 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2012 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2013 sectionsOp.getNowait());
2018 builder.restoreIP(*afterIP);
2022 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2023 privateReductionVariables, isByRef, sectionsOp.getNowait());
2030 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2037 assert(isByRef.size() == scopeOp.getNumReductionVars());
2046 scopeOp.getNumReductionVars());
2050 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2054 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2059 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2060 reductionDecls, privateReductionVariables, reductionVariableMap,
2065 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2067 builder.restoreIP(codeGenIP);
2073 return llvm::make_error<PreviouslyReportedError>();
2076 scopeOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2078 scopeOp.getPrivateNeedsBarrier())))
2079 return llvm::make_error<PreviouslyReportedError>();
2086 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2087 InsertPointTy oldIP = builder.saveIP();
2088 builder.restoreIP(codeGenIP);
2090 scopeOp.getLoc(), privateVarsInfo)))
2091 return llvm::make_error<PreviouslyReportedError>();
2092 builder.restoreIP(oldIP);
2093 return llvm::Error::success();
2096 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2097 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2098 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
2103 builder.restoreIP(*afterIP);
2107 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2108 privateReductionVariables, isByRef, scopeOp.getNowait(),
2116 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2117 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2122 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2124 builder.restoreIP(codegenIP);
2126 builder, moduleTranslation)
2129 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2133 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
2136 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
2137 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
2139 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2140 llvmCPFuncs.push_back(
2144 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2146 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2152 builder.restoreIP(*afterIP);
2156static omp::DistributeOp
2160 omp::DistributeOp distOp;
2161 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
2167 if (walk.wasInterrupted() || !distOp)
2171 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2175 for (
auto ra : iface.getReductionBlockArgs())
2176 for (
auto &use : ra.getUses()) {
2177 auto *useOp = use.getOwner();
2179 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2180 debugUses.push_back(useOp);
2183 if (!distOp->isProperAncestor(useOp))
2190 for (
auto *use : debugUses)
2199 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2204 unsigned numReductionVars = op.getNumReductionVars();
2208 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2214 if (doTeamsReduction) {
2215 isByRef =
getIsByRef(op.getReductionByref());
2217 assert(isByRef.size() == op.getNumReductionVars());
2220 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2225 op, reductionArgs, builder, moduleTranslation, allocaIP,
2226 reductionDecls, privateReductionVariables, reductionVariableMap,
2231 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2234 moduleTranslation, allocaIP, deallocBlocks);
2235 builder.restoreIP(codegenIP);
2241 llvm::Value *numTeamsLower =
nullptr;
2242 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2243 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2245 llvm::Value *numTeamsUpper =
nullptr;
2246 if (!op.getNumTeamsUpperVars().empty())
2247 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2249 llvm::Value *threadLimit =
nullptr;
2250 if (!op.getThreadLimitVars().empty())
2251 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2253 llvm::Value *ifExpr =
nullptr;
2254 if (
Value ifVar = op.getIfExpr())
2257 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2258 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2260 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2265 builder.restoreIP(*afterIP);
2266 if (doTeamsReduction) {
2269 op, builder, moduleTranslation, allocaIP, reductionDecls,
2270 privateReductionVariables, isByRef,
2276static llvm::omp::RTLDependenceKindTy
2279 case mlir::omp::ClauseTaskDepend::taskdependin:
2280 return llvm::omp::RTLDependenceKindTy::DepIn;
2284 case mlir::omp::ClauseTaskDepend::taskdependout:
2285 case mlir::omp::ClauseTaskDepend::taskdependinout:
2286 return llvm::omp::RTLDependenceKindTy::DepInOut;
2287 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2288 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2289 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2290 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2292 llvm_unreachable(
"unhandled depend kind");
2296 std::optional<ArrayAttr> dependKinds,
OperandRange dependVars,
2299 if (dependVars.empty())
2301 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2303 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2305 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2306 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2307 dds.emplace_back(dd);
2319 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2321 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2322 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2326 llvmBuilder.restoreIP(ip);
2332 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2333 return llvm::Error::success();
2338 ompBuilder.pushFinalizationCB(
2348 llvm::OpenMPIRBuilder &ompBuilder,
2349 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2350 ompBuilder.popFinalizationCB();
2351 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2352 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2353 cancelBranch->setSuccessor(constructFini);
2359class TaskContextStructManager {
2361 TaskContextStructManager(llvm::IRBuilderBase &builder,
2362 LLVM::ModuleTranslation &moduleTranslation,
2363 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2364 : builder{builder}, moduleTranslation{moduleTranslation},
2365 privateDecls{privateDecls} {}
2371 void generateTaskContextStruct();
2377 void createGEPsToPrivateVars();
2383 SmallVector<llvm::Value *>
2384 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2387 void freeStructPtr();
2389 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2390 return llvmPrivateVarGEPs;
/// Returns the pointer held in `structPtr`. The member starts out as nullptr
/// and is assigned the result of the `CreateMalloc` call made in
/// generateTaskContextStruct().
/// NOTE(review): presumably still nullptr when no context struct was
/// generated (e.g. no private declarations) — confirm callers check for that.
2393 llvm::Value *getStructPtr() {
return structPtr; }
2396 llvm::IRBuilderBase &builder;
2397 LLVM::ModuleTranslation &moduleTranslation;
2398 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2401 SmallVector<llvm::Type *> privateVarTypes;
2405 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2408 llvm::Value *structPtr =
nullptr;
2410 llvm::Type *structTy =
nullptr;
2421 llvm::SmallVector<llvm::Value *> lowerBounds;
2422 llvm::SmallVector<llvm::Value *> upperBounds;
2423 llvm::SmallVector<llvm::Value *> steps;
2424 llvm::SmallVector<llvm::Value *> trips;
2426 llvm::Value *totalTrips;
2428 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2429 llvm::IRBuilderBase &builder) {
2433 if (v->getType()->isIntegerTy(64))
2435 if (v->getType()->isIntegerTy())
2436 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2441 IteratorInfo(mlir::omp::IteratorOp itersOp,
2442 mlir::LLVM::ModuleTranslation &moduleTranslation,
2443 llvm::IRBuilderBase &builder) {
2444 dims = itersOp.getLoopLowerBounds().size();
2445 lowerBounds.resize(dims);
2446 upperBounds.resize(dims);
2450 for (
unsigned d = 0; d < dims; ++d) {
2451 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2452 moduleTranslation, builder);
2453 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2454 moduleTranslation, builder);
2456 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2457 assert(lb && ub && st &&
2458 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2459 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2460 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2461 "Expect non-zero step in IteratorOp");
2463 lowerBounds[d] = lb;
2464 upperBounds[d] = ub;
2468 llvm::Value *diff = builder.CreateSub(ub, lb);
2469 llvm::Value *
div = builder.CreateSDiv(diff, st);
2470 trips[d] = builder.CreateAdd(
2471 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2474 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2475 for (
unsigned d = 0; d < dims; ++d)
2476 totalTrips = builder.CreateMul(totalTrips, trips[d]);
/// Returns the number of iterator dimensions. The constructor sets this to the
/// number of loop lower bounds carried by the source omp.iterator op.
2479 unsigned getDims()
const {
return dims; }
/// Returns a view of the per-dimension lower bounds. The constructor stores
/// one entry per dimension, each obtained through lookUpAsI64().
2480 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
/// Returns a view of the per-dimension upper bounds. The constructor stores
/// one entry per dimension, each obtained through lookUpAsI64().
2481 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
/// Returns a view of the `steps` vector.
/// NOTE(review): presumably one step value per dimension, taken from the
/// iterator op's loop steps in the constructor — the store into `steps` is
/// not visible here, confirm.
2482 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
/// Returns a view of the per-dimension trip counts. The constructor computes
/// each entry as `(ub - lb) sdiv step + 1`.
2483 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
/// Returns the total trip count over all dimensions. The constructor builds
/// it by starting from the i64 constant 1 and multiplying in every
/// per-dimension trip count.
2484 llvm::Value *getTotalTrips()
const {
return totalTrips; }
2489void TaskContextStructManager::generateTaskContextStruct() {
2490 if (privateDecls.empty())
2492 privateVarTypes.reserve(privateDecls.size());
2494 for (omp::PrivateClauseOp &privOp : privateDecls) {
2497 if (!privOp.readsFromMold())
2499 Type mlirType = privOp.getType();
2500 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2503 if (privateVarTypes.empty())
2506 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2509 llvm::DataLayout dataLayout =
2510 builder.GetInsertBlock()->getModule()->getDataLayout();
2511 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2512 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2515 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2517 "omp.task.context_ptr");
2520SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2521 llvm::Value *altStructPtr)
const {
2522 SmallVector<llvm::Value *> ret;
2525 ret.reserve(privateDecls.size());
2526 llvm::Value *zero = builder.getInt32(0);
2528 for (
auto privDecl : privateDecls) {
2529 if (!privDecl.readsFromMold()) {
2531 ret.push_back(
nullptr);
2534 llvm::Value *iVal = builder.getInt32(i);
2535 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2542void TaskContextStructManager::createGEPsToPrivateVars() {
2544 assert(privateVarTypes.empty());
2548 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2551void TaskContextStructManager::freeStructPtr() {
2555 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2557 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2558 builder.CreateFree(structPtr);
2562 llvm::OpenMPIRBuilder &ompBuilder,
2563 llvm::Value *affinityList, llvm::Value *
index,
2564 llvm::Value *addr, llvm::Value *len) {
2565 llvm::StructType *kmpTaskAffinityInfoTy =
2566 ompBuilder.getKmpTaskAffinityInfoTy();
2567 llvm::Value *entry = builder.CreateInBoundsGEP(
2568 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
2570 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2571 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2573 llvm::Value *flags = builder.getInt32(0);
2575 builder.CreateStore(addr,
2576 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2577 builder.CreateStore(len,
2578 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2579 builder.CreateStore(flags,
2580 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2584 llvm::IRBuilderBase &builder,
2586 llvm::Value *affinityList) {
2587 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2588 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2589 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2591 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2592 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2593 assert(addr && len &&
"expect affinity addr and len to be non-null");
2595 affinityList, builder.getInt64(i), addr, len);
2599static mlir::LogicalResult
2602 llvm::IRBuilderBase &builder,
2604 llvm::Value *tmp = linearIV;
2605 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2606 llvm::Value *trip = iterInfo.getTrips()[d];
2608 llvm::Value *idx = builder.CreateURem(tmp, trip);
2610 tmp = builder.CreateUDiv(tmp, trip);
2613 llvm::Value *physIV = builder.CreateAdd(
2614 iterInfo.getLowerBounds()[d],
2615 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2621 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2622 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2625 return mlir::failure();
2627 return mlir::success();
2633static mlir::LogicalResult
2636 IteratorInfo &iterInfo, llvm::StringRef loopName,
2641 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2643 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2644 llvm::Value *linearIV) -> llvm::Error {
2645 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2646 builder.restoreIP(bodyIP);
2649 builder, moduleTranslation))) {
2650 return llvm::make_error<llvm::StringError>(
2651 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2655 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2656 assert(yield && yield.getResults().size() == 1 &&
2657 "expect omp.yield in iterator region to have one result");
2659 genStoreEntry(linearIV, yield);
2665 return llvm::Error::success();
2668 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2670 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2674 builder.restoreIP(*afterIP);
2676 return mlir::success();
2679static mlir::LogicalResult
2682 llvm::OpenMPIRBuilder::AffinityData &ad) {
2684 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2687 return mlir::success();
2691 llvm::StructType *kmpTaskAffinityInfoTy =
2694 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2695 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2696 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2698 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2699 "omp.affinity_list");
2702 auto createAffinity =
2703 [&](llvm::Value *count,
2704 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2705 llvm::OpenMPIRBuilder::AffinityData ad{};
2706 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2708 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2712 if (!taskOp.getAffinityVars().empty()) {
2713 llvm::Value *count = llvm::ConstantInt::get(
2714 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2715 llvm::Value *list = allocateAffinityList(count);
2718 ads.emplace_back(createAffinity(count, list));
2721 if (!taskOp.getIterated().empty()) {
2722 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2723 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2724 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2725 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2726 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2728 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2729 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2730 auto entryOp = yield.getResults()[0]
2731 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2732 assert(entryOp &&
"expect yield produce an affinity entry");
2739 affList, linearIV, addr, len);
2741 return llvm::failure();
2742 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2746 llvm::Value *totalAffinityCount = builder.getInt32(0);
2747 for (
const auto &affinity : ads)
2748 totalAffinityCount = builder.CreateAdd(
2750 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2753 llvm::Value *affinityInfo = ads.front().Info;
2754 if (ads.size() > 1) {
2755 llvm::StructType *kmpTaskAffinityInfoTy =
2757 llvm::Value *affinityInfoElemSize = builder.getInt64(
2758 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2759 kmpTaskAffinityInfoTy));
2761 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2762 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2763 for (
const auto &affinity : ads) {
2764 llvm::Value *affinityCount = builder.CreateIntCast(
2765 affinity.Count, builder.getInt32Ty(),
false);
2766 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2767 affinityCount, builder.getInt64Ty(),
false);
2768 llvm::Value *affinityInfoSize =
2769 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2771 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2772 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2774 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2775 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2777 builder.CreateMemCpy(
2778 packedAffinityInfoIndex, llvm::Align(1),
2779 builder.CreatePointerBitCastOrAddrSpaceCast(
2780 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2781 ->getPointerAddressSpace())),
2782 llvm::Align(1), affinityInfoSize);
2784 packedAffinityInfoOffset =
2785 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2788 affinityInfo = packedAffinityInfo;
2791 ad.Count = totalAffinityCount;
2792 ad.Info = affinityInfo;
2794 return mlir::success();
2800static mlir::LogicalResult
2803 std::optional<ArrayAttr> dependIteratedKinds,
2804 llvm::IRBuilderBase &builder,
2806 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2807 if (dependIterated.empty()) {
2810 return mlir::success();
2814 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2815 unsigned numLocator = dependVars.size();
2818 llvm::Value *totalCount =
2819 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2822 for (
auto iter : dependIterated) {
2823 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2824 assert(itersOp &&
"depend_iterated value must be defined by omp.iterator");
2825 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2827 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2832 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2833 llvm::Value *depArray =
2834 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2835 totalCount,
nullptr,
".dep.arr.addr");
2838 if (numLocator > 0) {
2841 for (
auto [i, dd] : llvm::enumerate(dds)) {
2842 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2843 llvm::Value *entry =
2844 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2845 ompBuilder.emitTaskDependency(builder, entry, dd);
2850 llvm::Value *offset =
2851 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2852 for (
auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2853 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2854 dependIteratedKinds->getValue()[i]);
2855 llvm::omp::RTLDependenceKindTy rtlKind =
2858 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2860 itersOp, builder, moduleTranslation, iterInfo,
"dep_iterator",
2861 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2863 moduleTranslation.
lookupValue(yield.getResults()[0]);
2864 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2865 llvm::Value *entry =
2866 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2867 ompBuilder.emitTaskDependency(
2869 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2872 return mlir::failure();
2875 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2878 taskDeps.DepArray = depArray;
2879 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2880 return mlir::success();
2887 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2892 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2904 InsertPointTy allocaIP =
2909 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2910 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2911 builder.getContext(),
"omp.task.start",
2912 builder.GetInsertBlock()->getParent());
2913 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2914 builder.SetInsertPoint(branchToTaskStartBlock);
2917 llvm::BasicBlock *copyBlock =
2918 splitBB(builder,
true,
"omp.private.copy");
2919 llvm::BasicBlock *initBlock =
2920 splitBB(builder,
true,
"omp.private.init");
2936 moduleTranslation, allocaIP, deallocBlocks);
2939 builder.SetInsertPoint(initBlock->getTerminator());
2942 taskStructMgr.generateTaskContextStruct();
2949 taskStructMgr.createGEPsToPrivateVars();
2951 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2954 taskStructMgr.getLLVMPrivateVarGEPs())) {
2956 if (!privDecl.readsFromMold())
2958 assert(llvmPrivateVarAlloc &&
2959 "reads from mold so shouldn't have been skipped");
2962 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2963 blockArg, llvmPrivateVarAlloc, initBlock);
2964 if (!privateVarOrErr)
2965 return handleError(privateVarOrErr, *taskOp.getOperation());
2974 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2975 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2976 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2978 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2979 llvmPrivateVarAlloc);
2981 assert(llvmPrivateVarAlloc->getType() ==
2982 moduleTranslation.
convertType(blockArg.getType()));
2992 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2993 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2994 taskOp.getPrivateNeedsBarrier())))
2995 return llvm::failure();
2997 llvm::OpenMPIRBuilder::AffinityData ad;
2999 return llvm::failure();
3002 builder.SetInsertPoint(taskStartBlock);
3005 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3010 moduleTranslation, allocaIP, deallocBlocks);
3013 builder.restoreIP(codegenIP);
3015 llvm::BasicBlock *privInitBlock =
nullptr;
3017 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3020 auto [blockArg, privDecl, mlirPrivVar] = zip;
3022 if (privDecl.readsFromMold())
3025 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3026 llvm::Type *llvmAllocType =
3027 moduleTranslation.
convertType(privDecl.getType());
3028 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3029 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3030 llvmAllocType,
nullptr,
"omp.private.alloc");
3033 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3034 blockArg, llvmPrivateVar, privInitBlock);
3035 if (!privateVarOrError)
3036 return privateVarOrError.takeError();
3037 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3038 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3041 taskStructMgr.createGEPsToPrivateVars();
3042 for (
auto [i, llvmPrivVar] :
3043 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3045 assert(privateVarsInfo.
llvmVars[i] &&
3046 "This is added in the loop above");
3049 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3054 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3058 if (!privateDecl.readsFromMold())
3061 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3062 llvmPrivateVar = builder.CreateLoad(
3063 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3065 assert(llvmPrivateVar->getType() ==
3066 moduleTranslation.
convertType(blockArg.getType()));
3067 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3071 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
3072 if (failed(
handleError(continuationBlockOrError, *taskOp)))
3073 return llvm::make_error<PreviouslyReportedError>();
3075 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3078 taskOp.getLoc(), privateVarsInfo)))
3079 return llvm::make_error<PreviouslyReportedError>();
3082 taskStructMgr.freeStructPtr();
3084 return llvm::Error::success();
3093 llvm::omp::Directive::OMPD_taskgroup);
3095 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3096 if (failed(
buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3097 taskOp.getDependIterated(),
3098 taskOp.getDependIteratedKinds(), builder,
3099 moduleTranslation, dependencies)))
3102 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3103 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3105 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3107 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dependencies, ad,
3108 taskOp.getMergeable(),
3109 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
3110 moduleTranslation.
lookupValue(taskOp.getPriority()));
3118 builder.restoreIP(*afterIP);
3120 if (dependencies.DepArray)
3121 builder.CreateFree(dependencies.DepArray);
// Fragment of a routine that handles the region of a taskloop wrapper op
// (`loopWrapperOp`).
// NOTE(review): the function name, most of its signature and several
// statements are not present in this view; the comments below describe only
// the statements that are visible.
3130 llvm::IRBuilderBase &builder,
// The wrapper's region and the label "omp.taskloop.wrapper.region" are passed
// to a call whose callee is not visible here; its result is presumably
// `continuationBlockOrError` -- TODO confirm.
3138 loopWrapperOp.getRegion(),
"omp.taskloop.wrapper.region", builder,
// The result is routed through handleError together with `opInst`.
3141 if (failed(
handleError(continuationBlockOrError, opInst)))
// Later instructions are inserted into the block held by
// `continuationBlockOrError`.
3144 builder.SetInsertPoint(continuationBlockOrError.get());
// Helper that yields the LLVM value for an MLIR `value` used as a taskloop
// loop bound, or an llvm::Error (a StringError carrying
// inconvertibleErrorCode) on failure.
// NOTE(review): the function name and several statements are missing from
// this view; comments are limited to what the visible lines show.
3152static llvm::Expected<llvm::Value *>
3155 llvm::IRBuilderBase &builder) {
// An existing mapping in `moduleTranslation` is looked up first.
3156 if (llvm::Value *mapped = moduleTranslation.
lookupValue(value))
// Error path: the value is an unmapped block argument.
3161 return llvm::make_error<llvm::StringError>(
3162 "value is a block argument and is not mapped",
3163 llvm::inconvertibleErrorCode());
// Error path: the defining op is not of a supported kind (the condition
// guarding this return is not visible here).
3165 return llvm::make_error<llvm::StringError>(
3166 "unsupported op defining taskloop loop bound",
3167 llvm::inconvertibleErrorCode());
// Each successfully resolved operand is mapped and remembered in
// `mappingsToRemove`.
3177 if (!operandOrError)
3178 return operandOrError.takeError();
3179 moduleTranslation.
mapValue(operand, *operandOrError);
3180 mappingsToRemove.push_back(operand);
// Error path: converting the defining op failed.
3184 return llvm::make_error<llvm::StringError>(
3185 "failed to convert op defining taskloop loop bound",
3186 llvm::inconvertibleErrorCode());
3189 assert(
result &&
"expected conversion of loop bound op to produce a value");
// Result values are also recorded in `mappingsToRemove`.
3193 mappingsToRemove.push_back(resultValue);
// The recorded mappings are iterated here; the loop body is not visible --
// presumably it removes the temporary mappings again, TODO confirm.
3195 for (
Value mappedValue : mappingsToRemove)
// Computes the lower bound, upper bound and step handed to the taskloop
// runtime entry and writes them to the out-parameters `lbVal`, `ubVal` and
// `stepVal`. Returns llvm::Error::success() once all three are set; errors
// from resolving individual bounds are propagated via takeError().
// NOTE(review): the function name, the leading parameters and a number of
// statements are not present in this view.
3204 llvm::Value *&lbVal, llvm::Value *&ubVal,
3205 llvm::Value *&stepVal) {
3213 return firstLbOrErr.takeError();
// The integer type of the first lower bound is used for every constant built
// below.
3215 llvm::Type *boundType = (*firstLbOrErr)->getType();
// `ubVal` starts at the constant 1 and is multiplied by each loop's trip count
// in the collapse branch.
3216 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3217 if (loopOp.getCollapseNumLoops() > 1) {
3235 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
// The already-resolved first lower bound is reused (moved from) for i == 0.
3237 i == 0 ? std::move(firstLbOrErr)
3241 return lbOrErr.takeError();
3243 upperBounds[i], moduleTranslation, builder);
3245 return ubOrErr.takeError();
3249 return stepOrErr.takeError();
3251 llvm::Value *loopLb = *lbOrErr;
3252 llvm::Value *loopUb = *ubOrErr;
3253 llvm::Value *loopStep = *stepOrErr;
// Inclusive distance between the bounds: ub - (lb - 1) when lb < ub (signed
// compare), otherwise lb - (ub - 1); the absolute value is taken afterwards.
3259 llvm::Value *loopLbMinusOne = builder.CreateSub(
3260 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3261 llvm::Value *loopUbMinusOne = builder.CreateSub(
3262 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3263 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3264 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3265 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3266 llvm::Value *loopTripCount =
3267 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3268 loopTripCount = builder.CreateBinaryIntrinsic(
3269 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
// Signed division and remainder of the distance by the step, each made
// non-negative with llvm.abs.
3273 llvm::Value *loopTripCountDivStep =
3274 builder.CreateSDiv(loopTripCount, loopStep);
3275 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3276 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3277 llvm::Value *loopTripCountRem =
3278 builder.CreateSRem(loopTripCount, loopStep);
3279 loopTripCountRem = builder.CreateBinaryIntrinsic(
3280 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
// `needsRoundUp` is an inequality test involving the remainder (the full
// operand list is not visible); it is zero-extended or truncated to the
// quotient's type and added to the quotient.
3281 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3283 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3286 builder.CreateAdd(loopTripCountDivStep,
3287 builder.CreateZExtOrTrunc(
3288 needsRoundUp, loopTripCountDivStep->getType()));
// Accumulate the product of the per-loop trip counts.
3289 ubVal = builder.CreateMul(ubVal, loopTripCount);
// In the collapse branch the lower bound and step are the constant 1.
3291 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3292 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
// Otherwise the resolved bounds are used directly (the assignment to `ubVal`
// on this path is not visible here).
3297 return ubOrErr.takeError();
3301 return stepOrErr.takeError();
3302 lbVal = *firstLbOrErr;
3304 stepVal = *stepOrErr;
// All three outputs must have been produced by one of the paths above.
3307 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3308 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3309 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3310 return llvm::Error::success();
// Fragment of the routine that lowers a taskloop context op (`contextOp`)
// together with the `omp::TaskloopWrapperOp` obtained from
// `contextOp.getLoopOp()`.
// Visible phases: (1) create the "omp.taskloop.wrapper.start" block and split
// off "omp.private.copy" / "omp.private.init" blocks, (2) build a task context
// struct through `taskStructMgr` and initialise privatized variables in it,
// (3) compute `lbVal`/`ubVal`/`stepVal`, (4) define a body callback and a task
// duplication callback, (5) gather clause operands and issue a builder call
// whose resulting insertion point is restored at the end.
// NOTE(review): the function name, parts of the signature and many statements
// are not present in this view; comments cover only what is visible.
3316 llvm::IRBuilderBase &builder,
3318 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3320 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3328 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// The builder is expected to sit at the end of its block; a new start block
// is created and branched to, and the builder is moved before that branch.
3335 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3336 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3337 builder.getContext(),
"omp.taskloop.wrapper.start",
3338 builder.GetInsertBlock()->getParent());
3339 llvm::Instruction *branchToTaskloopStartBlock =
3340 builder.CreateBr(taskloopStartBlock);
3341 builder.SetInsertPoint(branchToTaskloopStartBlock);
// Two blocks are split off for copying and initialising private variables.
3343 llvm::BasicBlock *copyBlock =
3344 splitBB(builder,
true,
"omp.private.copy");
3345 llvm::BasicBlock *initBlock =
3346 splitBB(builder,
true,
"omp.private.init");
3349 moduleTranslation, allocaIP, deallocBlocks);
3352 builder.SetInsertPoint(initBlock->getTerminator());
// Generate the context struct and one GEP per private variable.
3355 taskStructMgr.generateTaskContextStruct();
3356 taskStructMgr.createGEPsToPrivateVars();
3358 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
// Initialise, inside the struct, the private variables whose privatizer reads
// from the mold; the statement following the readsFromMold() check is not
// visible (presumably a `continue`).
3360 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3362 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3363 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3365 if (!privDecl.readsFromMold())
3367 assert(llvmPrivateVarAlloc &&
3368 "reads from mold so shouldn't have been skipped");
3371 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3372 blockArg, llvmPrivateVarAlloc, initBlock);
3373 if (!privateVarOrErr)
3374 return handleError(privateVarOrErr, *contextOp.getOperation());
3376 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3378 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3379 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
// When initialisation produced a different value and the block argument is
// not an LLVM pointer, store that value into the struct slot.
3381 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3382 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3383 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3385 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3386 llvmPrivateVarAlloc);
3388 assert(llvmPrivateVarAlloc->getType() ==
3389 moduleTranslation.
convertType(blockArg.getType()));
3395 contextOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3396 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3397 contextOp.getPrivateNeedsBarrier())))
3398 return llvm::failure();
// Continue emitting in the start block and compute the loop bounds.
3401 builder.SetInsertPoint(taskloopStartBlock);
3403 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3404 llvm::Value *lbVal =
nullptr;
3405 llvm::Value *ubVal =
nullptr;
3406 llvm::Value *stepVal =
nullptr;
3408 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
// Body callback (its name is not visible here; `bodyCB` is passed to the
// builder call further down).
3412 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3417 moduleTranslation, allocaIP, deallocBlocks);
3420 builder.restoreIP(codegenIP);
3422 llvm::BasicBlock *privInitBlock =
nullptr;
// Private variables handled here get an "omp.private.alloc" alloca placed
// before the terminator of the alloca block; the statement following the
// readsFromMold() check is not visible (presumably a `continue`).
3424 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3427 auto [blockArg, privDecl, mlirPrivVar] = zip;
3429 if (privDecl.readsFromMold())
3432 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3433 llvm::Type *llvmAllocType =
3434 moduleTranslation.
convertType(privDecl.getType());
3435 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3436 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3437 llvmAllocType,
nullptr,
"omp.private.alloc");
3440 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3441 blockArg, llvmPrivateVar, privInitBlock);
3442 if (!privateVarOrError)
3443 return privateVarOrError.takeError();
3444 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3445 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
// Struct GEPs are recreated and recorded in privateVarsInfo.llvmVars; the
// condition separating the assert from the assignment is not visible.
3448 taskStructMgr.createGEPsToPrivateVars();
3449 for (
auto [i, llvmPrivVar] :
3450 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3452 assert(privateVarsInfo.
llvmVars[i] &&
3453 "This is added in the loop above");
3456 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
// Map block arguments to their private values, loading the value first when
// the block argument is not an LLVM pointer.
3461 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3465 if (!privateDecl.readsFromMold())
3468 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3469 llvmPrivateVar = builder.CreateLoad(
3470 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3472 assert(llvmPrivateVar->getType() ==
3473 moduleTranslation.
convertType(blockArg.getType()));
3474 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
// The context op's region is passed along with the label
// "omp.taskloop.context.region"; failures are reported through handleError
// and surfaced as PreviouslyReportedError.
3480 contextOp.getRegion(),
"omp.taskloop.context.region", builder,
3483 if (failed(
handleError(continuationBlockOrError, opInst)))
3484 return llvm::make_error<PreviouslyReportedError>();
3486 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3494 contextOp.getLoc(), privateVarsInfo)))
3495 return llvm::make_error<PreviouslyReportedError>();
// Release the context struct at the end of the body.
3498 taskStructMgr.freeStructPtr();
3500 return llvm::Error::success();
// Task duplication callback: loads the source context pointer from `srcPtr`,
// generates a new context struct, stores its pointer to `destPtr`, and
// initialises the private variables from the source struct's GEPs.
3506 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3507 llvm::Value *destPtr, llvm::Value *srcPtr)
3509 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3510 builder.restoreIP(codegenIP);
3513 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3515 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
3517 TaskContextStructManager &srcStructMgr = taskStructMgr;
3518 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3520 destStructMgr.generateTaskContextStruct();
3521 llvm::Value *dest = destStructMgr.getStructPtr();
3522 dest->setName(
"omp.taskloop.context.dest");
3523 builder.CreateStore(dest, destPtr);
3526 srcStructMgr.createGEPsToPrivateVars(src);
3528 destStructMgr.createGEPsToPrivateVars(dest);
3531 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3532 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3535 if (!privDecl.readsFromMold())
3537 assert(llvmPrivateVarAlloc &&
3538 "reads from mold so shouldn't have been skipped");
3541 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3542 llvmPrivateVarAlloc, builder.GetInsertBlock());
3543 if (!privateVarOrErr)
3544 return privateVarOrErr.takeError();
3553 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3554 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3555 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3557 llvmPrivateVarAlloc = builder.CreateLoad(
3558 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3560 assert(llvmPrivateVarAlloc->getType() ==
3561 moduleTranslation.
convertType(blockArg.getType()));
3569 moduleTranslation, srcGEPs, destGEPs,
3571 contextOp.getPrivateNeedsBarrier())))
3572 return llvm::make_error<PreviouslyReportedError>();
3574 return builder.saveIP();
// Clause operands. `grainsize` is taken from the grainsize operand or, in the
// else-branch, from the num_tasks operand.
3582 llvm::Value *ifCond =
nullptr;
3583 llvm::Value *grainsize =
nullptr;
3585 mlir::Value grainsizeVal = contextOp.getGrainsize();
3586 mlir::Value numTasksVal = contextOp.getNumTasks();
3587 if (
Value ifVar = contextOp.getIfExpr())
3590 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3592 }
else if (numTasksVal) {
3593 grainsize = moduleTranslation.
lookupValue(numTasksVal);
// The duplication callback is only supplied when a context struct exists.
3597 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3598 if (taskStructMgr.getStructPtr())
3599 taskDupOrNull = taskDupCB;
3609 llvm::omp::Directive::OMPD_taskgroup);
// Arguments of the builder call that produces `afterIP` (the callee line is
// not visible here).
3611 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3612 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3614 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3615 stepVal, contextOp.getUntied(), ifCond, grainsize,
3616 contextOp.getNogroup(), sched,
3617 moduleTranslation.
lookupValue(contextOp.getFinal()),
3618 contextOp.getMergeable(),
3619 moduleTranslation.
lookupValue(contextOp.getPriority()),
3620 loopOp.getCollapseNumLoops(), taskDupOrNull,
3621 taskStructMgr.getStructPtr());
// Continue emitting after the generated construct.
3628 builder.restoreIP(*afterIP);
// Fragment of a lowering routine that wraps a region in a body callback and
// hands it to a builder call taking (ompLoc, allocaIP, deallocBlocks, bodyCB).
// NOTE(review): neither the function name nor the op kind nor the callee is
// visible in this view -- confirm against the full file.
3636 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// The body callback first moves the builder to the code generation point.
3640 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3642 builder.restoreIP(codegenIP);
3644 builder, moduleTranslation)
3649 InsertPointTy allocaIP =
3651 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3652 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3654 ompLoc, allocaIP, deallocBlocks, bodyCB);
// Continue emitting at the insertion point returned by the call.
3659 builder.restoreIP(*afterIP);
// Fragment of the routine that lowers an `omp::WsloopOp` (obtained via
// `cast<omp::WsloopOp>(opInst)`) wrapping an `omp::LoopNestOp`.
// Visible phases: read schedule / chunk / dist_schedule information, set up
// private and reduction variables, handle linear clause variables, convert
// the "omp.wsloop.region" region, call
// `ompBuilder->applyWorkshareLoop(...)`, then finalise linear variables.
// NOTE(review): the function name, signature and many statements are not
// present in this view; comments cover only what is visible.
3678 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3682 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3684 assert(isByRef.size() == wsloopOp.getNumReductionVars());
// A missing schedule kind falls back to Static.
3688 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
// The type of the first loop step is used as the induction variable type;
// chunk sizes are sign-extended or truncated to it.
3691 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3692 llvm::Type *ivType = step->getType();
3693 llvm::Value *chunk =
nullptr;
3694 if (wsloopOp.getScheduleChunk()) {
3695 llvm::Value *chunkVar =
3696 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3697 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
// dist_schedule information is read from the parent op when that parent is an
// `omp::DistributeOp`.
3700 omp::DistributeOp distributeOp =
nullptr;
3701 llvm::Value *distScheduleChunk =
nullptr;
3702 bool hasDistSchedule =
false;
3703 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3704 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3705 hasDistSchedule = distributeOp.getDistScheduleStatic();
3706 if (distributeOp.getDistScheduleChunkSize()) {
3707 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3708 distributeOp.getDistScheduleChunkSize());
3709 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3718 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3722 wsloopOp.getNumReductionVars());
3725 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3732 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3737 moduleTranslation, allocaIP, reductionDecls,
3738 privateReductionVariables, reductionVariableMap,
3739 deferredStores, isByRef)))
3748 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3750 wsloopOp.getPrivateNeedsBarrier())))
// Reduction variables are initialised relative to the single predecessor of
// the block held by `afterAllocas`.
3753 assert(afterAllocas.get()->getSinglePredecessor());
3754 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3756 afterAllocas.get()->getSinglePredecessor(),
3757 reductionDecls, privateReductionVariables,
3758 reductionVariableMap, isByRef, deferredStores)))
// Clause flags forwarded to applyWorkshareLoop below.
3762 bool isOrdered = wsloopOp.getOrdered().has_value();
3763 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3764 bool isSimd = wsloopOp.getScheduleSimd();
3765 bool loopNeedsBarrier = !wsloopOp.getNowait();
// The workshare loop type depends on whether the parent is a distribute op.
3770 llvm::omp::WorksharingLoopType workshareLoopType =
3771 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3772 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3773 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3777 llvm::omp::Directive::OMPD_for);
3779 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Linear clause handling: register the variable types, create the linear
// variables and initialise their steps.
3782 LinearClauseProcessor linearClauseProcessor;
3784 if (!wsloopOp.getLinearVars().empty()) {
3785 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3787 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3789 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3790 linearClauseProcessor.createLinearVar(
3791 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3793 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3794 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3798 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
// With linear variables present: initialise them in the loop preheader, emit
// a barrier, update them in the loop body, and split the finalisation block
// at the loop exit.
3806 if (!wsloopOp.getLinearVars().empty()) {
3807 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3808 loopInfo->getPreheader());
3809 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3811 builder.saveIP(), llvm::omp::OMPD_barrier);
3814 builder.restoreIP(*afterBarrierIP);
3815 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3816 loopInfo->getIndVar());
3817 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3820 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// `noLoopMode` defaults to false; the visible checks compare the loop nest
// with the innermost op captured by an enclosing target op and test that
// op's kernel exec flags against `no_loop`. The guard on `targetOp` and the
// assignment itself are not visible here.
3823 bool noLoopMode =
false;
3824 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3826 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3830 if (loopOp == targetCapturedOp) {
3831 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3832 omp::TargetExecMode::no_loop)
3837 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3838 ompBuilder->applyWorkshareLoop(
3839 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3840 convertToScheduleKind(schedule), chunk, isSimd,
3841 scheduleMod == omp::ScheduleModifier::monotonic,
3842 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3843 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
// Finalise linear variables using the loop's last-iteration value and rewrite
// their uses in the "omp.loop_nest.region" blocks; the previous insertion
// point is restored afterwards.
3849 if (!wsloopOp.getLinearVars().empty()) {
3850 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3851 assert(loopInfo->getLastIter() &&
3852 "`lastiter` in CanonicalLoopInfo is nullptr");
3853 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3854 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3855 loopInfo->getLastIter());
3858 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3859 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3861 builder.restoreIP(oldIP);
3869 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3870 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3875 wsloopOp.getLoc(), privateVarsInfo);
// Fragment of the routine that lowers a parallel region (`opInst`) through
// `ompBuilder->createParallel(...)`.
// Visible pieces: a body generation callback that sets up private and
// reduction variables and converts the "omp.par.region" region, a
// privatization callback `privCB`, a finalisation callback `finiCB` that
// inlines reduction cleanup regions, and the collection of the if /
// num_threads / proc_bind operands.
// NOTE(review): the function name, signature and many statements are not
// present in this view; comments cover only what is visible.
3882 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3884 assert(isByRef.size() == opInst.getNumReductionVars());
3897 opInst.getNumReductionVars());
// Body generation callback (its name is not visible; `bodyGenCB` is what is
// passed to createParallel below).
3901 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3904 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3906 return llvm::make_error<PreviouslyReportedError>();
3912 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
// An insertion point just before the terminator of the alloca block.
3915 InsertPointTy(allocaIP.getBlock(),
3916 allocaIP.getBlock()->getTerminator()->getIterator());
3919 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3920 reductionDecls, privateReductionVariables, reductionVariableMap,
3921 deferredStores, isByRef)))
3922 return llvm::make_error<PreviouslyReportedError>();
3924 assert(afterAllocas.get()->getSinglePredecessor());
3925 builder.restoreIP(codeGenIP);
3931 return llvm::make_error<PreviouslyReportedError>();
3934 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3936 opInst.getPrivateNeedsBarrier())))
3937 return llvm::make_error<PreviouslyReportedError>();
3940 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3941 afterAllocas.get()->getSinglePredecessor(),
3942 reductionDecls, privateReductionVariables,
3943 reductionVariableMap, isByRef, deferredStores)))
3944 return llvm::make_error<PreviouslyReportedError>();
3949 moduleTranslation, allocaIP, deallocBlocks);
3953 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3955 return regionBlock.takeError();
// Reductions are emitted only when the op has reduction variables.
3958 if (opInst.getNumReductionVars() > 0) {
3963 owningReductionGenRefDataPtrGens;
3965 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3967 owningReductionGenRefDataPtrGens,
3968 privateReductionVariables, reductionInfos, isByRef);
// A temporary `unreachable` is created at the region block's terminator and
// used as the insertion point for createReductions; it is erased again once
// the reductions have been emitted.
3971 builder.SetInsertPoint((*regionBlock)->getTerminator());
3974 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3975 builder.SetInsertPoint(tempTerminator);
3977 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3978 ompBuilder->createReductions(
3979 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3981 if (!contInsertPoint)
3982 return contInsertPoint.takeError();
// An insertion point without a block is treated as an already reported
// error.
3984 if (!contInsertPoint->getBlock())
3985 return llvm::make_error<PreviouslyReportedError>();
3987 tempTerminator->eraseFromParent();
3988 builder.restoreIP(*contInsertPoint);
3991 return llvm::Error::success();
// Privatization callback; its body is not visible in this view.
3994 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3995 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
// Finalisation callback: emits at `codeGenIP`, inlines the cleanup regions of
// the reduction declarations, and restores the previous insertion point
// before returning success.
4004 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4005 InsertPointTy oldIP = builder.saveIP();
4006 builder.restoreIP(codeGenIP);
4011 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4012 [](omp::DeclareReductionOp reductionDecl) {
4013 return &reductionDecl.getCleanupRegion();
4016 reductionCleanupRegions, privateReductionVariables,
4017 moduleTranslation, builder,
"omp.reduction.cleanup")))
4018 return llvm::createStringError(
4019 "failed to inline `cleanup` region of `omp.declare_reduction`");
4022 opInst.getLoc(), privateVarsInfo)))
4023 return llvm::make_error<PreviouslyReportedError>();
// For cancellable regions a barrier is created with directive kind
// OMPD_unknown (the remaining arguments are not visible here).
4027 if (isCancellable) {
4028 auto IPOrErr = ompBuilder->createBarrier(
4029 llvm::OpenMPIRBuilder::LocationDescription(builder),
4030 llvm::omp::Directive::OMPD_unknown,
4034 return IPOrErr.takeError();
4037 builder.restoreIP(oldIP);
4038 return llvm::Error::success();
// Clause operands: `ifCond` and `numThreads` stay null unless the
// corresponding clause is present; proc_bind defaults to
// OMP_PROC_BIND_default.
4041 llvm::Value *ifCond =
nullptr;
4042 if (
auto ifVar = opInst.getIfExpr())
4044 llvm::Value *numThreads =
nullptr;
4045 if (!opInst.getNumThreadsVars().empty())
4046 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
4047 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4048 if (
auto bind = opInst.getProcBindKind())
4052 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4054 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4056 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4057 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4058 privCB, finiCB, ifCond, numThreads, pbKind,
// Continue emitting after the generated parallel construct.
4064 builder.restoreIP(*afterIP);
// Converts an OpenMP dialect order clause kind to `llvm::omp::OrderKind`.
// One visible path returns OMP_ORDER_unknown (its guarding condition is not
// visible -- presumably an absent clause, TODO confirm);
// `ClauseOrderKind::Concurrent` maps to OMP_ORDER_concurrent; anything else
// reaches llvm_unreachable.
// NOTE(review): the function name and parameter list are not present in this
// view.
4069static llvm::omp::OrderKind
4072 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4074 case omp::ClauseOrderKind::Concurrent:
4075 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4077 llvm_unreachable(
"Unknown ClauseOrderKind kind");
// Fragment of the routine that lowers an `omp::SimdOp` (obtained via
// `cast<omp::SimdOp>(opInst)`).
// Visible phases: set up private, linear and reduction variables, translate
// the simdlen / safelen / aligned clauses, convert the "omp.simd.region"
// region, call `ompBuilder->applySimd(...)`, rewrite linear variable uses,
// then combine each private reduction value into the original variable and
// inline the reduction cleanup regions.
// NOTE(review): the function name, signature and many statements are not
// present in this view; comments cover only what is visible.
4085 auto simdOp = cast<omp::SimdOp>(opInst);
4093 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4096 simdOp.getNumReductionVars());
4101 assert(isByRef.size() == simdOp.getNumReductionVars());
4103 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4107 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
// Linear clause handling. A linear variable that is also one of the private
// variables is created from the matching LLVM private variable; the second
// createLinearVar call covers the other case (its guard is not visible).
4112 LinearClauseProcessor linearClauseProcessor;
4114 if (!simdOp.getLinearVars().empty()) {
4115 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4117 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4118 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4119 bool isImplicit =
false;
4120 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4124 if (linearVar == mlirPrivVar) {
4126 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4127 llvmPrivateVar, idx);
4133 linearClauseProcessor.createLinearVar(
4134 builder, moduleTranslation,
4137 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
4138 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4142 moduleTranslation, allocaIP, reductionDecls,
4143 privateReductionVariables, reductionVariableMap,
4144 deferredStores, isByRef)))
4155 assert(afterAllocas.get()->getSinglePredecessor());
4156 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4158 afterAllocas.get()->getSinglePredecessor(),
4159 reductionDecls, privateReductionVariables,
4160 reductionVariableMap, isByRef, deferredStores)))
// simdlen / safelen become 64-bit constants when the clause is present and
// stay null otherwise.
4163 llvm::ConstantInt *simdlen =
nullptr;
4164 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4165 simdlen = builder.getInt64(simdlenVar.value());
4167 llvm::ConstantInt *safelen =
nullptr;
4168 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4169 safelen = builder.getInt64(safelenVar.value());
// Aligned clause: each operand is loaded in `sourceBlock` (the block current
// at this point) and recorded with its alignment constant in `alignedVars`.
4171 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4174 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4175 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4177 for (
size_t i = 0; i < operands.size(); ++i) {
4178 llvm::Value *alignment =
nullptr;
4179 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
4180 llvm::Type *ty = llvmVal->getType();
4182 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4183 alignment = builder.getInt64(intAttr.getInt());
4184 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
4185 assert(alignment &&
"Invalid alignment value");
// Alignments that are not a power of two take a separate path whose statement
// is not visible here.
4189 if (!intAttr.getValue().isPowerOf2())
4192 auto curInsert = builder.saveIP();
4193 builder.SetInsertPoint(sourceBlock);
4194 llvmVal = builder.CreateLoad(ty, llvmVal);
4195 builder.restoreIP(curInsert);
4196 alignedVars[llvmVal] = alignment;
4200 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
// With linear variables present: initialise them in the loop preheader and
// update them in the loop body.
4207 if (simdOp.getLinearVars().size()) {
4208 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4209 loopInfo->getPreheader());
4211 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4212 loopInfo->getIndVar());
4214 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4216 ompBuilder->applySimd(loopInfo, alignedVars,
4218 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
4220 order, simdlen, safelen);
4222 linearClauseProcessor.emitStoresForLinearVar(builder);
// Detect `omp::OrderedRegionOp`s nested in the simd region; when any exist,
// linear variable uses in "omp.ordered.region" blocks are rewritten too.
4225 bool hasOrderedRegions =
false;
4226 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4227 hasOrderedRegions =
true;
4231 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
4232 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
4234 if (hasOrderedRegions) {
4236 linearClauseProcessor.rewriteInPlace(builder,
"omp.ordered.region",
4239 linearClauseProcessor.rewriteInPlace(builder,
"omp_region.finalize",
// Reduction epilogue: for each reduction, load the private value (and,
// through a load whose destination is not visible, the original value), run
// the reduction generator, and store the result to the original variable.
4248 for (
auto [i, tuple] : llvm::enumerate(
4249 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4250 privateReductionVariables))) {
4251 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4253 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
4254 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
4255 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
4259 llvm::Value *redValue = originalVariable;
4262 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
4263 llvm::Value *privateRedValue = builder.CreateLoad(
4264 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
4265 llvm::Value *reduced;
4267 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4270 builder.restoreIP(res.get());
4274 builder.CreateStore(reduced, originalVariable);
// Collect the cleanup regions of the reduction declarations and inline them
// under the label "omp.reduction.cleanup".
4279 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4280 [](omp::DeclareReductionOp reductionDecl) {
4281 return &reductionDecl.getCleanupRegion();
4284 moduleTranslation, builder,
4285 "omp.reduction.cleanup")))
// Fragment of the routine that lowers an `omp::LoopNestOp` (obtained via
// `cast<omp::LoopNestOp>(opInst)`) into canonical loops.
// Visible phases: create one canonical loop per nest level with
// `ompBuilder->createCanonicalLoop`, optionally tile them with
// `ompBuilder->tileLoops`, collapse the leading loops with
// `ompBuilder->collapseLoops`, publish the resulting loop info to the
// `OpenMPLoopInfoStackFrame`, and restore the builder to `afterIP`.
// NOTE(review): the function name, signature and several statements are not
// present in this view; comments cover only what is visible.
4297 auto loopOp = cast<omp::LoopNestOp>(opInst);
4303 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Body generator invoked per created loop. It uses the region argument whose
// index equals the number of loops created so far, records the insertion
// point, and converts the "omp.loop_nest.region" region only when
// loopInfos.size() == getNumLoops() - 1, i.e. for the last loop created.
4308 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4309 llvm::Value *iv) -> llvm::Error {
4312 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4317 bodyInsertPoints.push_back(ip);
4319 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4320 return llvm::Error::success();
4323 builder.restoreIP(ip);
4325 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
4327 return regionBlock.takeError();
4329 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4330 return llvm::Error::success();
// Create the canonical loops in nest order.
4338 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4339 llvm::Value *lowerBound =
4340 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
4341 llvm::Value *upperBound =
4342 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
4343 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
// By default the loop is placed at `ompLoc`. The override below targets the
// most recent body insertion point and computes at the first loop's
// preheader; its guarding condition is not visible here.
4348 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4349 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4351 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4353 computeIP = loopInfos.front()->getPreheaderIP();
4357 ompBuilder->createCanonicalLoop(
4358 loc, bodyGen, lowerBound, upperBound, step,
4359 true, loopOp.getLoopInclusive(), computeIP);
4364 loopInfos.push_back(*loopResult);
4367 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4368 loopInfos.front()->getAfterIP();
// Optional tiling: tile sizes become constants of the induction variable
// type.
4371 if (
const auto &tiles = loopOp.getTileSizes()) {
4372 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4375 for (
auto tile : tiles.value()) {
4376 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4377 tileSizes.push_back(tileVal);
4380 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4381 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
// After tiling, `afterIP` is the start of the single successor of the first
// new loop's after-block.
4385 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4386 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4387 afterIP = {afterAfterBB, afterAfterBB->begin()};
4391 for (
const auto &newLoop : newLoops)
4392 loopInfos.push_back(newLoop);
// Collapse the first `numCollapse` loops into one.
4396 const auto &numCollapse = loopOp.getCollapseNumLoops();
4398 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4400 auto newTopLoopInfo =
4401 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4403 assert(newTopLoopInfo &&
"New top loop information is missing");
// Publish the collapsed loop to the loop-info stack frame.
4404 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4405 [&](OpenMPLoopInfoStackFrame &frame) {
4406 frame.loopInfo = newTopLoopInfo;
4414 builder.restoreIP(afterIP);
// Fragment of a routine that lowers an op `op` exposing an induction
// variable, a trip count and an optional CLI value into an LLVM canonical
// loop named "omp.loop" via `ompBuilder->createCanonicalLoop`.
// NOTE(review): the function name, signature and several statements are not
// present in this view; comments cover only what is visible.
4424 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4425 Value loopIV = op.getInductionVar();
4426 Value loopTC = op.getTripCount();
4428 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4431 ompBuilder->createCanonicalLoop(
// Body callback: maps the MLIR induction variable to the LLVM one and emits
// at `ip`; a failure held in `bodyGenStatus` is propagated.
4433 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4436 moduleTranslation.
mapValue(loopIV, llvmIV);
4438 builder.restoreIP(ip);
4443 return bodyGenStatus.takeError();
4445 llvmTC,
"omp.loop");
// Builder errors are turned into a diagnostic on `op`.
4447 return op.emitError(llvm::toString(llvmOrError.takeError()));
// Continue emitting after the generated loop.
4449 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4450 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4451 builder.restoreIP(afterIP);
// The statement executed when a CLI value is present is not visible here.
4454 if (
Value cli = op.getCli())
// Fragment of a routine that applies heuristic unrolling to the loop named by
// `op.getApplyee()` through `ompBuilder->unrollLoopHeuristic`.
// NOTE(review): the function name, signature and the initialiser of
// `consBuilderCLI` are not present in this view.
4467 Value applyee = op.getApplyee();
4468 assert(applyee &&
"Loop to apply unrolling on required");
4470 llvm::CanonicalLoopInfo *consBuilderCLI =
4472 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4473 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
// Applies the tiling transformation described by `op`: translates the sizes
// clause operands and the applyee loops, calls `ompBuilder->tileLoops`, and
// (when the op has generatees) pairs each generatee with a generated loop.
// NOTE(review): the remainder of the signature and several statements are not
// present in this view.
4481static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4484 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
// Every size operand must already have an LLVM value.
4489 for (
Value size : op.getSizes()) {
4490 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4491 assert(translatedSize &&
4492 "sizes clause arguments must already be translated");
4493 translatedSizes.push_back(translatedSize);
// NOTE(review): the assert below tests `applyee` (the MLIR value being
// iterated) although its message refers to the translated loop; checking
// `consBuilderCLI` may have been intended -- confirm.
4496 for (
Value applyee : op.getApplyees()) {
4497 llvm::CanonicalLoopInfo *consBuilderCLI =
4499 assert(applyee &&
"Canonical loop must already been translated");
4500 translatedLoops.push_back(consBuilderCLI);
4503 auto generatedLoops =
4504 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
// zip_equal requires as many generatees as generated loops; the loop body is
// not visible here.
4505 if (!op.getGeneratees().empty()) {
4506 for (
auto [mlirLoop,
genLoop] :
4507 zip_equal(op.getGeneratees(), generatedLoops))
// The applyees are visited once more; the loop body is not visible here.
4512 for (
Value applyee : op.getApplyees())
/// Lowers an omp::FuseOp: splits the applyee loops into those before the fused
/// range, those to fuse, and those after it, calls
/// OpenMPIRBuilder::fuseLoops on the middle group, and maps the op's
/// generatees onto the resulting loops.
/// NOTE(review): several lines of this function (including its end) are not
/// visible in this excerpt.
4520static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4523 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4527 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4528 Value applyee = op.getApplyees()[i];
4529 llvm::CanonicalLoopInfo *consBuilderCLI =
// NOTE(review): this asserts on the MLIR value `applyee`, not on
// `consBuilderCLI`; confirm which one was meant to be checked.
4531 assert(applyee &&
"Canonical loop must already been translated");
// `first` is used as a 1-based position: indices below first-1 go to
// beforeFuse.
4532 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4533 beforeFuse.push_back(consBuilderCLI);
// NOTE(review): this branch reads op.getFirst().value() after checking only
// op.getCount().has_value(); confirm `first` is always set when `count` is.
4534 else if (op.getCount().has_value() &&
4535 i >= op.getFirst().value() + op.getCount().value() - 1)
4536 afterFuse.push_back(consBuilderCLI);
4538 toFuse.push_back(consBuilderCLI);
// When generatees are given there must be one per untouched loop plus one for
// the fused loop.
4541 (op.getGeneratees().empty() ||
4542 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4543 "Wrong number of generatees");
4546 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4547 if (!op.getGeneratees().empty()) {
// `i` is declared on a line not shown here.
4549 for (; i < beforeFuse.size(); i++)
4550 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4551 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
// NOTE(review): `i` continues from beforeFuse.size() + 1 here, yet it is both
// compared against afterFuse.size() and used to index afterFuse. That looks
// like a missing offset (leading afterFuse entries skipped, or the loop not
// running at all) -- confirm against the intended mapping.
4552 for (; i < afterFuse.size(); i++)
4553 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4557 for (
Value applyee : op.getApplyees())
/// Maps an omp::ClauseMemoryOrderKind to the corresponding
/// llvm::AtomicOrdering. The function name and parameter list are on lines not
/// shown in this excerpt.
4564static llvm::AtomicOrdering
// Early return of Monotonic; its guarding condition is not visible here.
4567 return llvm::AtomicOrdering::Monotonic;
4570 case omp::ClauseMemoryOrderKind::Seq_cst:
4571 return llvm::AtomicOrdering::SequentiallyConsistent;
4572 case omp::ClauseMemoryOrderKind::Acq_rel:
4573 return llvm::AtomicOrdering::AcquireRelease;
4574 case omp::ClauseMemoryOrderKind::Acquire:
4575 return llvm::AtomicOrdering::Acquire;
4576 case omp::ClauseMemoryOrderKind::Release:
4577 return llvm::AtomicOrdering::Release;
// Relaxed is expressed as LLVM's Monotonic ordering.
4578 case omp::ClauseMemoryOrderKind::Relaxed:
4579 return llvm::AtomicOrdering::Monotonic;
4581 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
// Fragment of the omp.atomic.read lowering; the function header is not part of
// this excerpt. Emits an atomic read of `x` into `v` through
// OpenMPIRBuilder::createAtomicRead.
4588 auto readOp = cast<omp::AtomicReadOp>(opInst);
4593 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4596 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// LLVM values for the location being read (x) and the destination (v).
4599 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4600 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
// Both operands are described with the op's element type.
4602 llvm::Type *elementType =
4603 moduleTranslation.
convertType(readOp.getElementType());
4605 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4606 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
// `AO` is defined on a line not shown here.
4607 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
// Fragment of the omp.atomic.write lowering; the function header is not part
// of this excerpt. Emits an atomic store of `expr` to `dest` through
// OpenMPIRBuilder::createAtomicWrite.
4615 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4620 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4623 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// LLVM values for the stored expression and the destination address.
4625 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4626 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
// The destination is described with the converted type of the expression.
4627 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4628 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4631 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
// Fragment of a TypeSwitch (its head is not part of this excerpt) that maps an
// LLVM dialect arithmetic op to the matching llvm::AtomicRMWInst::BinOp.
// Any op type not listed yields BAD_BINOP.
4639 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4640 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4641 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4642 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4643 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4644 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4645 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4646 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4647 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
4648 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
// Fragment of a helper (first header line not shown) that fills three boolean
// out-parameters from an atomic update op's atomic-control attribute.
// All three flags are reset to false first, so they stay false when the op is
// null or does not carry the attribute.
4652 bool &isIgnoreDenormalMode,
4653 bool &isFineGrainedMemory,
4654 bool &isRemoteMemory) {
4655 isIgnoreDenormalMode =
false;
4656 isFineGrainedMemory =
false;
4657 isRemoteMemory =
false;
// Only read the attribute when the op exists and actually has it.
4658 if (atomicUpdateOp &&
4659 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4660 mlir::omp::AtomicControlAttr atomicControlAttr =
4661 atomicUpdateOp.getAtomicControlAttr();
4662 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4663 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4664 isRemoteMemory = atomicControlAttr.getRemoteMemory();
// Fragment of the omp.atomic.update lowering; most of the function header is
// not part of this excerpt. The visible lines inspect the update region, build
// the update callback, and call OpenMPIRBuilder::createAtomicUpdate.
4671 llvm::IRBuilderBase &builder,
4678 auto &innerOpList = opInst.getRegion().front().getOperations();
4679 bool isXBinopExpr{
false};
// NOTE(review): `binop` is declared without an initializer; only the
// BAD_BINOP assignment further down is visible in this excerpt.
4680 llvm::AtomicRMWInst::BinOp binop;
4682 llvm::Value *llvmExpr =
nullptr;
4683 llvm::Value *llvmX =
nullptr;
4684 llvm::Type *llvmXElementType =
nullptr;
// A region with exactly two operations is treated specially: the region
// argument must be an operand of the inner op, otherwise an error is emitted.
4685 if (innerOpList.size() == 2) {
4691 opInst.getRegion().getArgument(0))) {
4692 return opInst.emitError(
"no atomic update operation with region argument"
4693 " as operand found inside atomic.update region");
// True when the region argument (x) is the first operand of the inner op.
4696 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4698 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4702 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4704 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4706 opInst.getRegion().getArgument(0).getType());
4707 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4711 llvm::AtomicOrdering atomicOrdering =
// Update callback: maps the region argument to the atomically loaded value,
// converts the region block in place, and returns the LLVM value of the
// single omp.yield result.
4716 [&opInst, &moduleTranslation](
4717 llvm::Value *atomicx,
4720 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4721 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4722 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4723 return llvm::make_error<PreviouslyReportedError>();
4725 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4726 assert(yieldop && yieldop.getResults().size() == 1 &&
4727 "terminator must be omp.yield op and it must have exactly one "
4729 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// These flags are passed to createAtomicUpdate below; the statement that
// fills them is not shown in this excerpt.
4732 bool isIgnoreDenormalMode;
4733 bool isFineGrainedMemory;
4734 bool isRemoteMemory;
4739 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4740 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4741 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4742 atomicOrdering, binop, updateFn,
4743 isXBinopExpr, isIgnoreDenormalMode,
4744 isFineGrainedMemory, isRemoteMemory);
// Continue emitting after the generated atomic update.
4749 builder.restoreIP(*afterIP);
// Fragment of the omp.atomic.capture lowering; most of the function header is
// not part of this excerpt. The capture wraps an atomic read plus either an
// atomic update or an atomic write, and is emitted through
// OpenMPIRBuilder::createAtomicCapture.
4755 llvm::IRBuilderBase &builder,
4762 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4763 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4765 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4766 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4768 assert((atomicUpdateOp || atomicWriteOp) &&
4769 "internal op must be an atomic.update or atomic.write op");
// With a write op the capture is always a postfix update and the stored
// expression is the write's expression.
4771 if (atomicWriteOp) {
4772 isPostfixUpdate =
true;
4773 mlirExpr = atomicWriteOp.getExpr();
// Otherwise it is a postfix update exactly when the update op is the second
// op of the capture.
4775 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4776 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4777 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
// Same two-operation region check as the plain atomic update lowering uses.
4780 if (innerOpList.size() == 2) {
4783 atomicUpdateOp.getRegion().getArgument(0))) {
4784 return atomicUpdateOp.emitError(
4785 "no atomic update operation with region argument"
4786 " as operand found inside atomic.update region");
4790 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4793 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
// x and v come from the capture's atomic read op; both are described with
// that op's element type.
4797 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4798 llvm::Value *llvmX =
4799 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4800 llvm::Value *llvmV =
4801 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4802 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4803 atomicCaptureOp.getAtomicReadOp().getElementType());
4804 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4807 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4811 llvm::AtomicOrdering atomicOrdering =
// Update callback: one path returns the write op's expression directly (its
// guard is not shown); the other converts the update region and returns the
// LLVM value of its single omp.yield result.
4815 [&](llvm::Value *atomicx,
4818 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4819 Block &bb = *atomicUpdateOp.getRegion().
begin();
4820 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4822 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4823 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4824 return llvm::make_error<PreviouslyReportedError>();
4826 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4827 assert(yieldop && yieldop.getResults().size() == 1 &&
4828 "terminator must be omp.yield op and it must have exactly one "
4830 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
// Flags are filled by a call whose first line is not shown (only its trailing
// arguments are visible below).
4833 bool isIgnoreDenormalMode;
4834 bool isFineGrainedMemory;
4835 bool isRemoteMemory;
4837 isFineGrainedMemory, isRemoteMemory);
4840 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4842 ompBuilder->createAtomicCapture(
4843 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4844 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4845 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4847 if (failed(
handleError(afterIP, *atomicCaptureOp)))
// Continue emitting after the generated atomic capture.
4850 builder.restoreIP(*afterIP);
// Maps an omp::ClauseCancellationConstructType to the llvm::omp::Directive
// being cancelled. The first line of the function header is not part of this
// excerpt.
4855 omp::ClauseCancellationConstructType directive) {
4856 switch (directive) {
// The `Loop` construct type corresponds to the `for` directive.
4857 case omp::ClauseCancellationConstructType::Loop:
4858 return llvm::omp::Directive::OMPD_for;
4859 case omp::ClauseCancellationConstructType::Parallel:
4860 return llvm::omp::Directive::OMPD_parallel;
4861 case omp::ClauseCancellationConstructType::Sections:
4862 return llvm::omp::Directive::OMPD_sections;
4863 case omp::ClauseCancellationConstructType::Taskgroup:
4864 return llvm::omp::Directive::OMPD_taskgroup;
4866 llvm_unreachable(
"Unhandled cancellation construct type");
// Fragment of the omp.cancel lowering; the function header is not part of this
// excerpt. Emits the cancel through OpenMPIRBuilder::createCancel.
4875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Stays null unless the op has an if-expression (the assignment inside the
// branch is not shown).
4878 llvm::Value *ifCond =
nullptr;
4879 if (
Value ifVar = op.getIfExpr())
4882 llvm::omp::Directive cancelledDirective =
4885 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4886 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4888 if (failed(
handleError(afterIP, *op.getOperation())))
// Continue emitting at the insertion point returned by the builder.
4891 builder.restoreIP(afterIP.get());
// Fragment of the omp.cancellation_point lowering; most of the function header
// is not part of this excerpt. Emits the cancellation point through
// OpenMPIRBuilder::createCancellationPoint.
4898 llvm::IRBuilderBase &builder,
4903 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4906 llvm::omp::Directive cancelledDirective =
4909 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4910 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4912 if (failed(
handleError(afterIP, *op.getOperation())))
// Continue emitting at the insertion point returned by the builder.
4915 builder.restoreIP(afterIP.get());
// Fragment of the omp.threadprivate lowering; the function header is not part
// of this excerpt. Resolves the global behind the op's symbol address and
// calls OpenMPIRBuilder::createCachedThreadPrivate for it.
4925 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4927 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4932 Value symAddr = threadprivateOp.getSymAddr();
// An address-space cast on the symbol is handled first (branch body not
// shown).
4935 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
// The symbol must come from an LLVM::AddressOfOp; anything else is an error.
4938 if (!isa<LLVM::AddressOfOp>(symOp))
4939 return opInst.
emitError(
"Addressing symbol not found");
4940 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4942 LLVM::GlobalOp global =
4943 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4944 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
// The size passed to the runtime is the store size of the global's value
// type under the module's data layout.
4945 llvm::Type *type = globalValue->getValueType();
4946 llvm::TypeSize typeSize =
4947 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4949 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
// The cache variable is named after the global with a ".cache" suffix.
4950 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4951 ompLoc, globalValue, size, global.getSymName() +
".cache");
/// Maps an mlir::omp::DeclareTargetDeviceType to the matching
/// OffloadEntriesInfoManager::OMPTargetDeviceClauseKind. The function name and
/// parameter list are on a line not shown in this excerpt.
4957static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4959 switch (deviceClause) {
4960 case mlir::omp::DeclareTargetDeviceType::host:
4961 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4963 case mlir::omp::DeclareTargetDeviceType::nohost:
4964 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4966 case mlir::omp::DeclareTargetDeviceType::any:
4967 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4970 llvm_unreachable(
"unhandled device clause");
/// Maps an mlir::omp::DeclareTargetCaptureClause to the matching
/// OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind. The function name
/// is on a line not shown in this excerpt.
4973static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4975 mlir::omp::DeclareTargetCaptureClause captureClause) {
4976 switch (captureClause) {
4977 case mlir::omp::DeclareTargetCaptureClause::to:
4978 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4979 case mlir::omp::DeclareTargetCaptureClause::link:
4980 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4981 case mlir::omp::DeclareTargetCaptureClause::enter:
4982 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4983 case mlir::omp::DeclareTargetCaptureClause::none:
4984 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4986 llvm_unreachable(
"unhandled capture clause");
// Fragment of a helper (header not shown) that resolves the symbol behind an
// LLVM::AddressOfOp. Both casts use dyn_cast_if_present, so a null `op` simply
// fails to match.
// An address-space cast is handled first (branch body not shown).
4991 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
// For an address-of op, look its global name up in the enclosing module.
4993 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4994 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4995 return modOp.lookupSymbol(addressOfOp.getGlobalName());
5002 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5003 value = addrCast.getOperand();
/// Builds a name suffix (returned as llvm::SmallString<64>) that ends in
/// "_decl_tgt_ref_ptr". The function name and first parameter are on a line
/// not shown in this excerpt.
5020static llvm::SmallString<64>
5022 llvm::OpenMPIRBuilder &ompBuilder,
5023 llvm::vfs::FileSystem &vfs) {
// The suffix is assembled by streaming into `suffix`.
5025 llvm::raw_svector_ostream os(suffix);
// Callback giving the builder the (filename, line) pair of `loc`.
5028 auto fileInfoCallBack = [&loc]() {
5029 return std::pair<std::string, uint64_t>(
5030 llvm::StringRef(loc.getFilename()), loc.getLine());
// The FileID of the builder's unique target-entry info is part of the
// suffix (the surrounding statement is not fully shown).
5035 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5037 os <<
"_decl_tgt_ref_ptr";
// Fragment of a predicate (header not shown): checks whether the operation
// implements omp::DeclareTargetInterface and its capture clause is `link`.
5043 if (
auto declareTargetGlobal =
5044 dyn_cast_if_present<omp::DeclareTargetInterface>(
5046 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5047 omp::DeclareTargetCaptureClause::link)
// Fragment of a predicate (header not shown): checks whether the operation
// implements omp::DeclareTargetInterface and its capture clause is `to` or
// `enter`.
5053 if (
auto declareTargetGlobal =
5054 dyn_cast_if_present<omp::DeclareTargetInterface>(
5056 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5057 omp::DeclareTargetCaptureClause::to ||
5058 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5059 omp::DeclareTargetCaptureClause::enter)
// Fragment of a helper (header not shown) that deals with the reference
// pointer of a declare-target global `gOp`.
5073 if (
auto declareTargetGlobal =
5074 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
// Applies to `link` globals, and to `to` globals when the builder config
// requires unified shared memory.
5077 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
5078 omp::DeclareTargetCaptureClause::link) ||
5079 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5080 omp::DeclareTargetCaptureClause::to &&
5081 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
// A symbol name that already contains the suffix is handled separately from
// one that needs the suffix appended (both branch bodies are only partly
// shown).
5085 if (gOp.getSymName().contains(suffix))
5090 (gOp.getSymName().str() + suffix.str()).str());
/// Extends llvm::OpenMPIRBuilder::MapInfosTy with one extra parallel vector,
/// `Mappers`, holding an Operation pointer per map entry.
5099struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
5100 SmallVector<Operation *, 4> Mappers;
/// Appends all entries of `curInfo`, including its Mappers, to this object.
5103 void append(MapInfosTy &curInfo) {
5104 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5105 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
/// Extends MapInfosTy with further per-entry vectors used while lowering map
/// clauses: declare-target / member / mapping flags, the originating map
/// clause operation, the original LLVM value, and the LLVM base type.
5114struct MapInfoData : MapInfosTy {
5115 llvm::SmallVector<bool, 4> IsDeclareTarget;
5116 llvm::SmallVector<bool, 4> IsAMember;
5118 llvm::SmallVector<bool, 4> IsAMapping;
5119 llvm::SmallVector<mlir::Operation *, 4> MapClause;
5120 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
5123 llvm::SmallVector<llvm::Type *, 4> BaseType;
/// Appends the entries of `CurInfo` to this object.
/// NOTE(review): IsAMember and IsAMapping are not appended here, so those
/// vectors would fall out of step with the others after an append -- confirm
/// whether that is intended.
5126 void append(MapInfoData &CurInfo) {
5127 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5128 CurInfo.IsDeclareTarget.end());
5129 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5130 OriginalValue.append(CurInfo.OriginalValue.begin(),
5131 CurInfo.OriginalValue.end());
5132 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5133 MapInfosTy::append(CurInfo);
/// Identifies which target-related directive is being lowered. Only the
/// TargetEnterData enumerator is visible in this excerpt.
5137enum class TargetDirectiveEnumTy : uint32_t {
5141 TargetEnterData = 3,
/// Classifies `op` by its concrete OpenMP target op type and returns the
/// matching TargetDirectiveEnumTy; any other operation yields
/// TargetDirectiveEnumTy::None.
5146static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5147 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5148 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
5149 .Case([](omp::TargetEnterDataOp) {
5150 return TargetDirectiveEnumTy::TargetEnterData;
5152 .Case([&](omp::TargetExitDataOp) {
5153 return TargetDirectiveEnumTy::TargetExitData;
5155 .Case([&](omp::TargetUpdateOp) {
5156 return TargetDirectiveEnumTy::TargetUpdate;
5158 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
// Fallback for every operation type not listed above.
5159 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
5166 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5167 arrTy.getElementType()))
// Fragment of a helper (first header lines not shown) that computes a size
// value for a mapped object. When the map clause carries bounds, the result
// is an element count multiplied by the element size in bytes.
5184 llvm::Value *basePointer,
5185 llvm::Type *baseType,
5186 llvm::IRBuilderBase &builder,
5188 if (
auto memberClause =
5189 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5194 if (!memberClause.getBounds().empty()) {
// Start at 1 and multiply in a term per bounds operand. The visible operands
// of that term are the upper bound, the lower bound, and the constant 1
// (the operators combining them are on lines not shown).
5195 llvm::Value *elementCount = builder.getInt64(1);
5196 for (
auto bounds : memberClause.getBounds()) {
5197 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5198 bounds.getDefiningOp())) {
5203 elementCount = builder.CreateMul(
5207 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
5208 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
5209 builder.getInt64(1)));
// Array types get special handling (branch body not shown).
5216 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
// Convert the underlying element size from bits to bytes before scaling.
5224 return builder.CreateMul(elementCount,
5225 builder.getInt64(underlyingTypeSzInBits / 8));
/// Translates omp::ClauseMapFlags bits into
/// llvm::omp::OpenMPOffloadMappingFlags bits. The function name and parameter
/// list are on a line not shown in this excerpt.
5236static llvm::omp::OpenMPOffloadMappingFlags
// True when any flag other than is_device_ptr is set.
5238 const bool hasExplicitMap =
5239 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
5240 omp::ClauseMapFlags::none;
5242 llvm::omp::OpenMPOffloadMappingFlags mapType =
5243 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// Each MLIR flag below sets the offload flag of the same meaning.
5245 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::to))
5246 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5248 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::from))
5249 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5251 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::always))
5252 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5254 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::del))
5255 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5257 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::return_param))
5258 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5260 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::priv))
5261 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5263 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::literal))
5264 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5266 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::implicit))
5267 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5269 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::close))
5270 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5272 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::present))
5273 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5275 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::ompx_hold))
5276 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5278 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::attach))
5279 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
// is_device_ptr always adds TARGET_PARAM; LITERAL is added only when no other
// map flag was given.
5281 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::is_device_ptr)) {
5282 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5283 if (!hasExplicitMap)
5284 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// Fragment of the routine that fills `mapData` from map operands; the start
// of the function header is not part of this excerpt. Every push_back below
// adds one element to one of mapData's parallel vectors, so each processed
// operand must add exactly one element to each of them.
5294 ArrayRef<Value> useDevAddrOperands = {},
5295 ArrayRef<Value> hasDevAddrOperands = {}) {
// True when the map type has ref_ptr or ref_ptee set together with attach.
5297 auto checkRefPtrOrPteeMapWithAttach = [](omp::ClauseMapFlags mapType) {
5299 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptr) ||
5300 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptee);
5301 return hasRefType &&
5302 bitEnumContainsAll(mapType, omp::ClauseMapFlags::attach);
// Scans the member lists of every map in `mapVars` for `mapOp`.
5305 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
5313 for (Value mapValue : mapVars) {
5314 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5315 for (
auto member : map.getMembers())
5316 if (member == mapOp)
// Main pass: one entry per map operand.
5323 for (Value mapValue : mapVars) {
5324 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5325 bool isRefPtrOrPteeMapWithAttach =
5326 checkRefPtrOrPteeMapWithAttach(mapOp.getMapType());
// Use var_ptr_ptr as the offload pointer when present, except for
// ref_ptr/ref_ptee attach maps, which use var_ptr.
5327 Value offloadPtr = (mapOp.getVarPtrPtr() && !isRefPtrOrPteeMapWithAttach)
5328 ? mapOp.getVarPtrPtr()
5329 : mapOp.getVarPtr();
5330 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
// For those attach maps the pointer entry is var_ptr_ptr instead.
5331 mapData.Pointers.push_back(
5332 isRefPtrOrPteeMapWithAttach
5333 ? moduleTranslation.
lookupValue(mapOp.getVarPtrPtr())
5334 : mapData.OriginalValue.back());
// Three cases follow (their conditions are only partly shown): a
// declare-target entry whose base pointer is `refPtr`, a declare-target entry
// based on the original value, and a non-declare-target entry.
5336 if (llvm::Value *refPtr =
5338 mapData.IsDeclareTarget.push_back(
true);
5339 mapData.BasePointers.push_back(refPtr);
5341 mapData.IsDeclareTarget.push_back(
true);
5342 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5344 mapData.IsDeclareTarget.push_back(
false);
5345 mapData.BasePointers.push_back(mapData.OriginalValue.back());
// Base type: var_ptr_ptr's type when present, otherwise var_ptr's type.
5351 mapData.BaseType.push_back(moduleTranslation.
convertType(
5352 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5353 : mapOp.getVarPtrType()));
// The type used for sizing follows the same attach-map exception as
// offloadPtr above.
5360 mlir::Type sizeType = (isRefPtrOrPteeMapWithAttach || !mapOp.getVarPtrPtr())
5361 ? mapOp.getVarPtrType()
5362 : mapOp.getVarPtrPtrType().value();
5364 dl, sizeType, isRefPtrOrPteeMapWithAttach ?
nullptr : mapOp,
5365 mapData.Pointers.back(), moduleTranslation.
convertType(sizeType),
5366 builder, moduleTranslation));
5367 mapData.MapClause.push_back(mapOp.getOperation());
5369 mapData.Names.push_back(LLVM::createMappingInformation(
5371 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
// Record the user-defined mapper when the clause names one, else null.
5372 if (mapOp.getMapperId())
5373 mapData.Mappers.push_back(
5375 mapOp, mapOp.getMapperIdAttr()));
5377 mapData.Mappers.push_back(
nullptr);
5378 mapData.IsAMapping.push_back(
true);
5379 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
// Looks for an existing, non-attach mapping entry for `val` with the same
// member count; on a match it marks the entry RETURN_PARAM and records the
// device info kind.
5382 auto findMapInfo = [&mapData](llvm::Value *val,
5383 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
5384 size_t memberCount) {
5387 for (llvm::Value *basePtr : mapData.OriginalValue) {
5388 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[index]);
5399 (mapData.Types[index] &
5400 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5401 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5402 if (!isAttachMap && basePtr == val && mapData.IsAMapping[index] &&
5403 memberCount == mapOp.getMembers().size()) {
5405 mapData.Types[index] |=
5406 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5407 mapData.DevicePointers[index] = devInfoTy;
// Handles use_device_* operands: reuse a matching entry via findMapInfo, or
// append a new zero-sized RETURN_PARAM entry that is not a mapping.
5415 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
5416 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5417 for (Value mapValue : useDevOperands) {
5418 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5420 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5421 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5424 if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
5425 mapData.OriginalValue.push_back(origValue);
5426 mapData.Pointers.push_back(mapData.OriginalValue.back());
5427 mapData.IsDeclareTarget.push_back(
false);
5428 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5429 mlir::Type baseTy = mapOp.getVarPtrPtr()
5430 ? mapOp.getVarPtrPtrType().value()
5431 : mapOp.getVarPtrType();
5432 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
5433 mapData.Sizes.push_back(builder.getInt64(0));
5434 mapData.MapClause.push_back(mapOp.getOperation());
5435 mapData.Types.push_back(
5436 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5437 mapData.Names.push_back(LLVM::createMappingInformation(
5439 mapData.DevicePointers.push_back(devInfoTy);
5440 mapData.Mappers.push_back(
nullptr);
5441 mapData.IsAMapping.push_back(
false);
5442 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
// use_device_addr operands are tagged Address, use_device_ptr operands
// Pointer.
5447 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5448 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
// has_device_addr operands: always appended as new, non-mapping entries.
5450 for (Value mapValue : hasDevAddrOperands) {
5451 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5453 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5454 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5456 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5458 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5459 omp::ClauseMapFlags::none;
5461 mapData.OriginalValue.push_back(origValue);
5462 mapData.BasePointers.push_back(origValue);
5463 mapData.Pointers.push_back(origValue);
5464 mapData.IsDeclareTarget.push_back(
false);
5466 mlir::Type baseTy = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5467 : mapOp.getVarPtrType();
5468 mapData.BaseType.push_back(moduleTranslation.
convertType(baseTy));
// Size is the data-layout size of the base type.
5469 mapData.Sizes.push_back(builder.getInt64(dl.
getTypeSize(baseTy)));
5471 mapData.MapClause.push_back(mapOp.getOperation());
// With the ALWAYS bit set, keep the converted map type and any user-defined
// mapper.
5472 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5476 mapData.Types.push_back(mapType);
5480 if (mapOp.getMapperId()) {
5481 mapData.Mappers.push_back(
5483 mapOp, mapOp.getMapperIdAttr()));
5485 mapData.Mappers.push_back(
nullptr);
// Otherwise device pointers keep their map type, everything else is passed
// as LITERAL, and no mapper is recorded.
5490 mapData.Types.push_back(
5491 isDevicePtr ? mapType
5492 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5493 mapData.Mappers.push_back(
nullptr);
5495 mapData.Names.push_back(LLVM::createMappingInformation(
5497 mapData.DevicePointers.push_back(
5498 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5499 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5500 mapData.IsAMapping.push_back(
false);
5501 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
// Fragment of a helper (header not shown) that returns the position of
// `memberOp` within mapData.MapClause. The op is required to be present; this
// is only checked by the assert.
5506 auto *res = llvm::find(mapData.MapClause, memberOp);
5507 assert(res != mapData.MapClause.end() &&
5508 "MapInfoOp for member not found in MapData, cannot return index");
5509 return std::distance(mapData.MapClause.begin(), res);
// Fragment of a helper (first header line not shown) that orders member
// positions by their index paths from the map's members_index attribute.
// Comparator lines that are not shown decide what each branch returns.
5513 omp::MapInfoOp mapInfo,
bool first =
true) {
5514 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Each member's index path is itself an ArrayAttr of integers.
5524 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5525 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
// Compare the two paths element by element over their common length.
5527 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5528 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5529 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5531 if (aIndex == bIndex)
5534 if (aIndex < bIndex)
5537 if (aIndex > bIndex)
// After the loop, the shorter path is treated as the parent and one of the
// two positions is recorded in occludedChildren (the condition selecting
// `a` or `b` is not shown).
5544 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5546 occludedChildren.push_back(
b);
5548 occludedChildren.push_back(a);
5549 return memberAParent;
5552 for (
auto v : occludedChildren)
// Fragment of a helper (header not shown) that returns one member MapInfoOp
// of `mapInfo`.
5559 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// With a single member there is nothing to choose between.
5561 if (indexAttr.size() == 1)
5562 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
// Otherwise return the member at the front of `indices`, which is prepared
// on lines not shown.
5566 return llvm::cast<omp::MapInfoOp>(
5567 mapInfo.getMembers()[
indices.front()].getDefiningOp());
/// Builds a list of index values from map bounds, returned as
/// std::vector<llvm::Value *>. The function name and part of the parameter
/// list are on lines not shown in this excerpt.
5590static std::vector<llvm::Value *>
5592 llvm::IRBuilderBase &builder,
bool isArrayTy,
5594 std::vector<llvm::Value *> idx;
// First visible path: a leading zero index followed by the lower bound of
// each bounds op, walking the bounds from last to first.
5605 idx.push_back(builder.getInt64(0));
5606 for (
int i = bounds.size() - 1; i >= 0; --i) {
5607 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5608 bounds[i].getDefiningOp())) {
5609 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Second visible path: fold the bounds into a single running index. The last
// bounds op seeds it with its lower bound; every earlier one updates it to
// idx.back() * extent + lower bound.
5627 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5628 for (
int i = bounds.size() - 1; i >= 0; --i) {
5629 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5630 bounds[i].getDefiningOp())) {
5631 if (i == ((
int)bounds.size() - 1))
5633 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5635 idx.back() = builder.CreateAdd(
5636 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
5637 boundOp.getExtent())),
5638 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
// Fragment of a helper (header not shown): appends to `ints` the integer
// value of every attribute in `values`; each element must be an IntegerAttr.
5647 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
5648 return cast<IntegerAttr>(value).getInt();
// Fragment of a helper (first header line not shown) that fills
// `overlapMapDataIdxs` with positions of members of `parentOp`, dropping any
// member whose index path extends another member's path.
5656 omp::MapInfoOp parentOp) {
// No members: nothing to select (the statement in this branch is not shown).
5658 if (parentOp.getMembers().empty())
// A single member is always selected.
5662 if (parentOp.getMembers().size() == 1) {
5663 overlapMapDataIdxs.push_back(0);
5667 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5668 size_t numMembers = indexAttr.size();
// Decode every member's index path into integers.
5672 for (
auto [i, indicesAttr] : llvm::enumerate(indexAttr))
5673 getAsIntegers(cast<ArrayAttr>(indicesAttr), memberIndices[i]);
// Member i is skipped when some member j has a strictly shorter path that is
// a prefix of i's path.
5679 llvm::SmallDenseSet<size_t> skipIndices;
5680 for (
size_t i = 0; i < numMembers; ++i) {
5681 const auto &iIndices = memberIndices[i];
5682 for (
size_t j = 0;
j < numMembers; ++
j) {
5685 const auto &jIndices = memberIndices[
j];
5687 if (jIndices.size() < iIndices.size() &&
5688 std::equal(jIndices.begin(), jIndices.end(), iIndices.begin())) {
5689 skipIndices.insert(i);
// Keep the remaining positions in their original order.
5696 for (
size_t i = 0; i < numMembers; ++i)
5697 if (!skipIndices.contains(i))
5698 overlapMapDataIdxs.push_back(i);
5710 if (mapOp.getVarPtrPtr())
5733 llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData,
5734 size_t mapDataIdx, MapInfosTy &combinedInfo,
5735 TargetDirectiveEnumTy targetDirective,
5736 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5737 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
5738 bool isTargetParam =
true,
int mapDataParentIdx = -1) {
// Appends the map entry at `mapDataIdx` to `combinedInfo`. The signature
// defaults are: no member-of flag (OMP_MAP_NONE), isTargetParam == true and
// no parent entry (mapDataParentIdx == -1).
//
// Start from the flags already recorded for this entry; they are adjusted
// below before being pushed.
5739 auto mapFlag = mapData.Types[mapDataIdx];
5740 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
// Tail of a bit test that checks whether OMP_MAP_ATTACH is set.
5744 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5745 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
// Conditions visible here for adding OMP_MAP_TARGET_PARAM: the caller asked
// for a target parameter, the directive is Target, and the entry is not
// declare-target.
5751 if (isTargetParam &&
5752 (targetDirective == TargetDirectiveEnumTy::Target &&
5753 !mapData.IsDeclareTarget[mapDataIdx]) &&
5755 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
// By-copy captures are candidates for OMP_MAP_LITERAL.
5757 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5759 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
// When a member-of flag was supplied, it is applied only to entries that are
// neither pointer-typed nor attach maps.
5768 if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) {
5769 if (!isPtrTy && !isAttachMap)
5770 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// Clear OMP_MAP_RETURN_PARAM from the flags.
5777 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// Pointer-typed, non-attach, declare-target entries get OMP_MAP_PTR_AND_OBJ.
5787 if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx])
5788 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
// Tail of a test for "ref_ptee set but ref_ptr not set" on the map type.
5797 !bitEnumContainsAll(mapInfoOp.getMapType(),
5798 omp::ClauseMapFlags::ref_ptr) &&
5799 bitEnumContainsAll(mapInfoOp.getMapType(), omp::ClauseMapFlags::ref_ptee);
// True when both ref_ptr and ref_ptee are set on the map type.
5800 bool isRefPtrPtee = bitEnumContainsAll(mapInfoOp.getMapType(),
5801 omp::ClauseMapFlags::ref_ptr |
5802 omp::ClauseMapFlags::ref_ptee);
// Base pointer selection: the parent's base pointer is used when the op is
// not nested in a DeclareMapperOp, a parent index was given, and the map is
// neither ref_ptee-only nor a pointer-typed ref_ptr+ref_ptee map. The entry's
// own base pointer is used on the other visible path.
5804 if (!mapInfoOp->getParentOfType<omp::DeclareMapperOp>() &&
5805 mapDataParentIdx >= 0 && !(isRefPtee || (isRefPtrPtee && isPtrTy))) {
5806 combinedInfo.BasePointers.emplace_back(
5807 mapData.BasePointers[mapDataParentIdx]);
5809 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5812 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
// Entries carrying a member-of flag never report device-pointer info.
5813 combinedInfo.DevicePointers.emplace_back(
5814 memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE
5815 ? llvm::OpenMPIRBuilder::DeviceInfoTy::None
5816 : mapData.DevicePointers[mapDataIdx]);
5817 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5818 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5819 combinedInfo.Types.emplace_back(mapFlag);
// For pointer-typed entries the size becomes 0 at run time when the mapped
// pointer is null; otherwise the recorded size is used unchanged.
5820 combinedInfo.Sizes.emplace_back(
5821 isPtrTy ? builder.CreateSelect(
5822 builder.CreateIsNull(mapData.Pointers[mapDataIdx]),
5823 builder.getInt64(0), mapData.Sizes[mapDataIdx])
5824 : mapData.Sizes[mapDataIdx]);
5844 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5845 MapInfoData &mapData, uint64_t mapDataIndex,
5846 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5847 TargetDirectiveEnumTy targetDirective) {
// Emits the combinedInfo entries for the parent map at `mapDataIndex`.
5848 using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
// Host-side code generation only.
5849 assert(!ompBuilder.Config.isTargetDevice() &&
5850 "function only supported for host device codegen");
5852 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5853 auto *parentMapper = mapData.Mappers[mapDataIndex];
// The first entry starts as OMP_MAP_TARGET_PARAM for a non-declare-target
// entry of a Target directive, and as OMP_MAP_NONE otherwise.
5859 MapFlags baseFlag = (targetDirective == TargetDirectiveEnumTy::Target &&
5860 !mapData.IsDeclareTarget[mapDataIndex])
5861 ? MapFlags::OMP_MAP_TARGET_PARAM
5862 : MapFlags::OMP_MAP_NONE;
// Two alternative "preserve" masks follow (each declares its own
// `parentFlags`, so they live in separate scopes). This one carries the
// parent's TO/FROM/ALWAYS/CLOSE/PRESENT/OMPX_HOLD/IMPLICIT bits over.
5868 MapFlags parentFlags = mapData.Types[mapDataIndex];
5869 MapFlags preserve = MapFlags::OMP_MAP_TO | MapFlags::OMP_MAP_FROM |
5870 MapFlags::OMP_MAP_ALWAYS | MapFlags::OMP_MAP_CLOSE |
5871 MapFlags::OMP_MAP_PRESENT |
5872 MapFlags::OMP_MAP_OMPX_HOLD |
5873 MapFlags::OMP_MAP_IMPLICIT;
5874 baseFlag |= (parentFlags & preserve);
// Second mask; its visible tail includes PRESENT and RETURN_PARAM.
5876 MapFlags parentFlags = mapData.Types[mapDataIndex];
5878 MapFlags::OMP_MAP_PRESENT | MapFlags::OMP_MAP_RETURN_PARAM;
5879 baseFlag |= (parentFlags & preserve);
// Push the first entry's flags and device-pointer info.
5882 combinedInfo.Types.emplace_back(baseFlag);
5883 combinedInfo.DevicePointers.emplace_back(
5884 mapData.DevicePointers[mapDataIndex]);
// The parent's mapper is attached only when one exists and the parent is not
// a partial map.
5888 combinedInfo.Mappers.emplace_back(
5889 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
5891 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5892 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
// [lowAddr, highAddr) delimits the address range covered by the first entry.
5901 llvm::Value *lowAddr, *highAddr;
// Full (non-partial) map: the range runs from the parent's pointer to one
// element of its base type past that pointer.
5902 if (!parentClause.getPartialMap()) {
5903 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5904 builder.getPtrTy());
5905 highAddr = builder.CreatePointerCast(
5906 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5907 mapData.Pointers[mapDataIndex], 1),
5908 builder.getPtrTy());
5909 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
// Other path: the range runs from the first member's base pointer to one
// element past the last member's base pointer.
5911 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5914 lowAddr = builder.CreatePointerCast(mapData.BasePointers[firstMemberIdx],
5915 builder.getPtrTy());
5919 auto lastMemberMapInfo =
5920 cast<omp::MapInfoOp>(mapData.MapClause[lastMemberIdx]);
// True when the last member's map type has ref_ptee set but not ref_ptr.
5929 bool isRefPteeMap = bitEnumContainsAll(lastMemberMapInfo.getMapType(),
5930 omp::ClauseMapFlags::ref_ptee) &&
5931 !bitEnumContainsAll(lastMemberMapInfo.getMapType(),
5932 omp::ClauseMapFlags::ref_ptr);
// Element type used to step past the last member; it starts as the member's
// base type.
5933 llvm::Type *castType = mapData.BaseType[lastMemberIdx];
5936 moduleTranslation.
convertType(lastMemberMapInfo.getVarPtrType());
5937 highAddr = builder.CreatePointerCast(
5938 builder.CreateGEP(castType, mapData.BasePointers[lastMemberIdx],
5939 builder.getInt64(1)),
5940 builder.getPtrTy());
5941 combinedInfo.Pointers.emplace_back(mapData.BasePointers[firstMemberIdx]);
// Size of the first entry: pointer difference highAddr - lowAddr measured in
// i8 elements, cast to i64.
5944 llvm::Value *size = builder.CreateIntCast(
5945 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5946 builder.getInt64Ty(),
5948 combinedInfo.Sizes.push_back(size);
// Only full (non-partial) parent maps get the additional entries below.
5956 if (!parentClause.getPartialMap()) {
5961 MapFlags mapFlag = mapData.Types[mapDataIndex];
// Record whether CLOSE is set before the member-of bits are applied.
5962 bool hasMapClose = (MapFlags(mapFlag) & MapFlags::OMP_MAP_CLOSE) ==
5963 MapFlags::OMP_MAP_CLOSE;
5964 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// Single-entry form: for target update, for CLOSE maps, or when there is
// exactly one overlap index, the parent is emitted as one entry with its
// recorded size and no mapper.
5980 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose ||
5981 overlapIdxs.size() == 1) {
5982 combinedInfo.Types.emplace_back(mapFlag);
5983 combinedInfo.DevicePointers.emplace_back(
5984 mapData.DevicePointers[mapDataIndex]);
5986 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5987 combinedInfo.BasePointers.emplace_back(
5988 mapData.BasePointers[mapDataIndex]);
5989 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5990 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5991 combinedInfo.Mappers.emplace_back(
nullptr);
// Segmented form: reset the range to the whole parent object.
5997 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5998 builder.getPtrTy());
5999 highAddr = builder.CreatePointerCast(
6000 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6001 mapData.Pointers[mapDataIndex], 1),
6002 builder.getPtrTy());
// The segment entries do not carry OMP_MAP_RETURN_PARAM.
6009 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// One entry per overlap index. Each entry starts at the current lowAddr, and
// lowAddr is then advanced to one element past that member's base pointer.
6016 for (
auto v : overlapIdxs) {
6019 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
6021 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataOverlapIdx]));
6022 combinedInfo.Types.emplace_back(mapFlag);
6023 combinedInfo.DevicePointers.emplace_back(
6024 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6026 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6027 combinedInfo.BasePointers.emplace_back(
6028 mapData.BasePointers[mapDataIndex]);
6029 combinedInfo.Mappers.emplace_back(
nullptr);
6030 combinedInfo.Pointers.emplace_back(lowAddr);
// Segment size as an i8-element pointer difference involving the member's
// original value, cast to i64 as signed.
6031 auto sizeCalc = builder.CreateIntCast(
6032 builder.CreatePtrDiff(builder.getInt8Ty(),
6033 mapData.OriginalValue[mapDataOverlapIdx],
6035 builder.getInt64Ty(),
true);
// A computed size of zero is replaced by a fallback: the size of a pointer
// for pointer maps, otherwise the member's recorded size.
6040 auto sizeSel = builder.CreateSelect(
6041 builder.CreateICmpNE(builder.getInt64(0), sizeCalc), sizeCalc,
6042 isPtrMap ? llvm::ConstantExpr::getSizeOf(builder.getPtrTy())
6043 : mapData.Sizes[mapDataOverlapIdx]);
6044 combinedInfo.Sizes.emplace_back(sizeSel);
6045 lowAddr = builder.CreateConstGEP1_32(
6046 isPtrMap ? builder.getPtrTy() : mapData.BaseType[mapDataOverlapIdx],
6047 mapData.BasePointers[mapDataOverlapIdx], 1);
// Trailing entry: covers what is left between lowAddr and highAddr.
6050 combinedInfo.Types.emplace_back(mapFlag);
6051 combinedInfo.DevicePointers.emplace_back(
6052 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6054 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6055 combinedInfo.BasePointers.emplace_back(
6056 mapData.BasePointers[mapDataIndex]);
6057 combinedInfo.Mappers.emplace_back(
nullptr);
6058 combinedInfo.Pointers.emplace_back(lowAddr);
6059 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
6060 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6061 builder.getInt64Ty(),
true));
6067 llvm::IRBuilderBase &builder,
6068 llvm::OpenMPIRBuilder &ompBuilder,
6070 MapInfoData &mapData, uint64_t mapDataIndex,
6071 TargetDirectiveEnumTy targetDirective) {
// Handles the map entry at `mapDataIndex` together with its member maps.
// Host-side code generation only.
6072 assert(!ompBuilder.Config.isTargetDevice() &&
6073 "function only supported for host device codegen");
6076 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// Shortcut for a partial parent map with exactly one member. The visible
// call arguments for that member are OMP_MAP_NONE, `true`, and this parent's
// index.
6081 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6082 auto memberClause = llvm::cast<omp::MapInfoOp>(
6083 parentClause.getMembers()[0].getDefiningOp());
6096 builder, ompBuilder, mapData, memberDataIdx, combinedInfo,
6098 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6099 true, mapDataIndex);
// General path: a local helper walks parentClause's members to gather map
// data indices into `mapInfoIdx`.
6103 auto collectMapInfoIdxs =
6106 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6108 for (
auto member : parentClause.getMembers())
6110 mapData, llvm::cast<omp::MapInfoOp>(member.getDefiningOp())));
6114 collectMapInfoIdxs(mapInfoIdx);
// The member-of flag is derived from the number of entries already present
// in combinedInfo.Types at this point.
6116 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6117 ompBuilder.getMemberOfFlag(combinedInfo.Types.size());
// Every collected index is processed with that member-of flag; the second
// visible call passes `false` and this parent's index as trailing arguments.
6118 for (
size_t i = 0; i < mapInfoIdx.size(); i++) {
6123 combinedInfo, mapData, mapInfoIdx[i], memberOfFlag,
6127 combinedInfo, targetDirective, memberOfFlag,
6128 false, mapDataIndex);
6140 llvm::IRBuilderBase &builder) {
6142 "function only supported for host device codegen");
// Visit every map clause and rewrite its pointer (and, for by-copy captures,
// its base pointer) in `mapData` according to the capture kind.
6143 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6144 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Tail of a bit test that checks whether OMP_MAP_ATTACH is set.
6147 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6148 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
// Declare-target entries are left alone unless they are attach maps.
6153 if (!mapData.IsDeclareTarget[i] ||
6154 (mapData.IsDeclareTarget[i] && isAttachMap)) {
6155 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6165 switch (captureKind) {
// ByRef: start from the recorded pointer, optionally load through it, then
// apply the computed offset indices (if any) with an in-bounds GEP.
6166 case omp::VariableCaptureKind::ByRef: {
6167 llvm::Value *newV = mapData.Pointers[i];
6169 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
6172 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6174 if (!offsetIdx.empty())
6175 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6177 mapData.Pointers[i] = newV;
// ByCopy: take the value itself (loading it when the recorded pointer is
// pointer-typed), store it into a pointer-typed ".casted" stack slot, and
// reload it as a pointer.
6179 case omp::VariableCaptureKind::ByCopy: {
6180 llvm::Type *type = mapData.BaseType[i];
6182 if (mapData.Pointers[i]->getType()->isPointerTy())
6183 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6185 newV = mapData.Pointers[i];
// The insertion point and debug location are saved before the alloca and
// restored afterwards.
6188 auto curInsert = builder.saveIP();
6189 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6191 auto *memTempAlloc =
6192 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
6193 builder.SetCurrentDebugLocation(DbgLoc);
6194 builder.restoreIP(curInsert);
6196 builder.CreateStore(newV, memTempAlloc);
6197 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
// For by-copy captures the reloaded value becomes both pointer and base.
6200 mapData.Pointers[i] = newV;
6201 mapData.BasePointers[i] = newV;
// `This` and `VLAType` captures are reported as an op error.
6203 case omp::VariableCaptureKind::This:
6204 case omp::VariableCaptureKind::VLAType:
6205 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
6216 MapInfoData &mapData,
6217 TargetDirectiveEnumTy targetDirective) {
6219 "function only supported for host device codegen");
// Walk every map clause; entries flagged as members of another map are
// tested first.
6240 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6241 if (mapData.IsAMember[i])
// NOTE(review): the dyn_cast result is dereferenced on the next line without
// a null check, so every MapClause entry is presumably a MapInfoOp — confirm,
// or use cast<> to make that assumption explicit.
6244 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Maps that have members take a dedicated path.
6245 if (!mapInfoOp.getMembers().empty()) {
6247 combinedInfo, mapData, i, targetDirective);
// Forward declaration of a function that takes a module translation, a
// mapper function name and the target directive kind, and returns either an
// llvm::Function or an error.
6256static llvm::Expected<llvm::Function *>
6258 LLVM::ModuleTranslation &moduleTranslation,
6259 llvm::StringRef mapperFuncName,
6260 TargetDirectiveEnumTy targetDirective);
// Returns the mapper function for a declare-mapper op, reusing an existing
// function where one can be found.
6262static llvm::Expected<llvm::Function *>
6265 TargetDirectiveEnumTy targetDirective) {
6267 "function only supported for host device codegen");
6268 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
// The function name is built from "omp_mapper" and the op's symbol name.
6269 std::string mapperFuncName =
6271 {
"omp_mapper", declMapperOp.getSymName()});
// First lookup: a function already registered with the module translation
// under this name.
6273 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
// Second lookup: a function with this name already present in the LLVM
// module. It is registered with the module translation and returned.
6281 if (llvm::Function *existingFunc =
6282 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
6283 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
6284 return existingFunc;
// Trailing arguments of the final call, which receives the computed name and
// the directive kind.
6288 mapperFuncName, targetDirective);
// Emits the mapper function named `mapperFuncName` for a declare-mapper op
// and registers it with the module translation.
6291static llvm::Expected<llvm::Function *>
6294 llvm::StringRef mapperFuncName,
6295 TargetDirectiveEnumTy targetDirective) {
6297 "function only supported for host device codegen");
6298 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6299 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
// LLVM type of the variable the mapper is declared for.
6302 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
6305 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Filled in by the map-info callback below and returned from it.
6308 MapInfosTy combinedInfo;
// Map-info callback: binds the mapper's symbol value to the incoming pointer
// PHI, translates the first block of the mapper's region at the requested
// insertion point, then generates the map infos.
6310 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6311 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6312 builder.restoreIP(codeGenIP);
6313 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
6314 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
6315 builder.GetInsertBlock());
// A failed block conversion is surfaced as PreviouslyReportedError.
6316 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
6319 return llvm::make_error<PreviouslyReportedError>();
6320 MapInfoData mapData;
6323 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6329 return combinedInfo;
// Custom-mapper callback fragment: entries without a mapper are tested first.
6333 if (!combinedInfo.Mappers[i])
6336 moduleTranslation, targetDirective);
6340 genMapInfoCB, varType, mapperFuncName, customMapperCB,
// Propagate an emission failure to the caller.
6343 return newFn.takeError();
// If a function is already registered, it must be the one just emitted
// (checked by assertion only).
6344 if ([[maybe_unused]] llvm::Function *mappedFunc =
6346 assert(mappedFunc == *newFn &&
6347 "mapper function mapping disagrees with emitted function");
6349 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
// Defaults before the op-specific dispatch: no `if` condition and the
// "undefined" device id constant.
6357 llvm::Value *ifCond =
nullptr;
6358 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
// Runtime entry point; assigned in the enter/exit/update cases below.
6362 llvm::omp::RuntimeFunction RTLFn;
6364 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6367 llvm::OpenMPIRBuilder::TargetDataInfo info(
6370 assert(!ompBuilder->Config.isTargetDevice() &&
6371 "target data/enter/exit/update are host ops");
// True when at least one target triple is configured.
6372 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
// Helper: look up the translated device value and cast it to i64 as a signed
// integer.
6374 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
6375 llvm::Value *v = moduleTranslation.
lookupValue(dev);
6376 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
// target data: picks up the if/device clauses and the map, use_device_ptr
// and use_device_addr operands.
6381 .Case([&](omp::TargetDataOp dataOp) {
6385 if (
auto ifVar = dataOp.getIfExpr())
6389 deviceID = getDeviceID(devId);
6391 mapVars = dataOp.getMapVars();
6392 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6393 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
// target enter data: selects the begin-mapper runtime call, using the nowait
// variant when the op has nowait.
6396 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6400 if (
auto ifVar = enterDataOp.getIfExpr())
6404 deviceID = getDeviceID(devId);
6407 enterDataOp.getNowait()
6408 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6409 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6410 mapVars = enterDataOp.getMapVars();
6411 info.HasNoWait = enterDataOp.getNowait();
// target exit data: same shape, with the end-mapper runtime call.
6414 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6418 if (
auto ifVar = exitDataOp.getIfExpr())
6422 deviceID = getDeviceID(devId);
6424 RTLFn = exitDataOp.getNowait()
6425 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6426 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6427 mapVars = exitDataOp.getMapVars();
6428 info.HasNoWait = exitDataOp.getNowait();
// target update: same shape, with the update-mapper runtime call.
6431 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6435 if (
auto ifVar = updateDataOp.getIfExpr())
6439 deviceID = getDeviceID(devId);
6442 updateDataOp.getNowait()
6443 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6444 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6445 mapVars = updateDataOp.getMapVars();
6446 info.HasNoWait = updateDataOp.getNowait();
// Any other operation kind is a programming error.
6449 .DefaultUnreachable(
"unexpected operation");
// With no target triples configured, the `if` condition is forced to a
// constant false.
6454 if (!isOffloadEntry)
6455 ifCond = builder.getFalse();
6457 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6458 MapInfoData mapData;
6460 builder, useDevicePtrVars, useDeviceAddrVars);
// Map-info callback: generates the combined map information at the requested
// insertion point and returns a reference to it.
6463 MapInfosTy combinedInfo;
6464 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6465 builder.restoreIP(codeGenIP);
6466 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6468 return combinedInfo;
// Helper that binds use_device_* block arguments to LLVM values. For each
// (block argument, use-device variable) pair it searches the collected map
// data for an entry with the same base pointer operand and the requested
// device-info kind.
6474 [&moduleTranslation](
6475 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6479 for (
auto [arg, useDevVar] :
6480 llvm::zip_equal(blockArgs, useDeviceVars)) {
// A map's base operand is its var_ptr_ptr when present, else its var_ptr.
6482 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6483 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6484 : mapInfoOp.getVarPtr();
6487 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6488 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6489 mapInfoData.MapClause, mapInfoData.DevicePointers,
6490 mapInfoData.BasePointers)) {
6491 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6492 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6493 devicePointer != type)
// The optional `mapper` callback may translate the base pointer; the block
// argument is bound only when the resulting value is non-null.
6496 if (llvm::Value *devPtrInfoMap =
6497 mapper ? mapper(basePointer) : basePointer) {
6498 moduleTranslation.
mapValue(arg, devPtrInfoMap);
// Body callback, only valid for target data ops (asserted below). It is
// invoked per body-generation kind.
6505 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6506 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6507 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6510 builder.restoreIP(codeGenIP);
6511 assert(isa<omp::TargetDataOp>(op) &&
6512 "BodyGen requested for non TargetDataOp");
6513 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
// NOTE(review): `®ion` looks like a mis-decoded `&region` (HTML entity
// `&reg`) — confirm against the original file.
6514 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
6515 switch (bodyGenType) {
// Priv: when device pointer info exists, use_device_addr arguments are bound
// to a load from the recorded slot and use_device_ptr arguments to the slot
// itself.
6516 case BodyGenTy::Priv:
6518 if (!info.DevicePtrInfoMap.empty()) {
6519 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6520 blockArgIface.getUseDeviceAddrBlockArgs(),
6521 useDeviceAddrVars, mapData,
6522 [&](llvm::Value *basePointer) -> llvm::Value * {
6523 if (!info.DevicePtrInfoMap[basePointer].second)
6525 return builder.CreateLoad(
6527 info.DevicePtrInfoMap[basePointer].second);
6529 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6530 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6531 mapData, [&](llvm::Value *basePointer) {
6532 return info.DevicePtrInfoMap[basePointer].second;
6536 moduleTranslation)))
6537 return llvm::make_error<PreviouslyReportedError>();
// DupNoPriv: when there is no device pointer info, the block arguments are
// bound without a translating callback.
6540 case BodyGenTy::DupNoPriv:
6541 if (info.DevicePtrInfoMap.empty()) {
6544 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6545 blockArgIface.getUseDeviceAddrBlockArgs(),
6546 useDeviceAddrVars, mapData);
6547 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6548 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
// NoPriv: acts only when there is no device pointer info.
6552 case BodyGenTy::NoPriv:
6554 if (info.DevicePtrInfoMap.empty()) {
6556 moduleTranslation)))
6557 return llvm::make_error<PreviouslyReportedError>();
6561 return builder.saveIP();
// Custom-mapper callback fragment: entries without a mapper are tested first;
// otherwise info.HasMapper is set.
6564 auto customMapperCB =
6566 if (!combinedInfo.Mappers[i])
6568 info.HasMapper =
true;
6570 moduleTranslation, targetDirective);
6573 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6575 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Emit the construct. The second createTargetData call, which passes &RTLFn,
// follows the isa<TargetDataOp> branch.
6577 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6578 if (isa<omp::TargetDataOp>(op))
6579 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6580 deallocBlocks, deviceID, ifCond, info,
6581 genMapInfoCB, customMapperCB,
6584 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6585 deallocBlocks, deviceID, ifCond, info,
6586 genMapInfoCB, customMapperCB, &RTLFn);
// Continue emitting after the generated construct.
6592 builder.restoreIP(*afterIP);
// Translation of an omp.distribute operation.
6600 auto distributeOp = cast<omp::DistributeOp>(opInst);
6607 bool doDistributeReduction =
// Zero reduction variables when there is no teams op.
6611 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
// When reductions are handled here, collect the by-ref flag of each teams
// reduction variable.
6616 if (doDistributeReduction) {
6617 isByRef =
getIsByRef(teamsOp.getReductionByref());
6618 assert(isByRef.size() == teamsOp.getNumReductionVars());
6621 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6625 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6626 .getReductionBlockArgs();
6629 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6630 reductionDecls, privateReductionVariables, reductionVariableMap,
// Body callback passed to createDistribute below. Failures inside it are
// reported as PreviouslyReportedError.
6635 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6637 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
6642 moduleTranslation, allocaIP, deallocBlocks);
6645 builder.restoreIP(codeGenIP);
6649 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
6651 return llvm::make_error<PreviouslyReportedError>();
6656 return llvm::make_error<PreviouslyReportedError>();
6659 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
6661 distributeOp.getPrivateNeedsBarrier())))
6662 return llvm::make_error<PreviouslyReportedError>();
6665 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6668 builder, moduleTranslation);
6670 return regionBlock.takeError();
// Continue emitting at the start of the translated region's block.
6671 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
// When the nested wrapper is not a wsloop, the workshare loop is applied
// here.
6676 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
// With static dist_schedule the Distribute schedule kind is used (and the
// loop is flagged ordered); otherwise plain Static.
6679 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6680 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6681 : omp::ClauseScheduleKind::Static;
6683 bool isOrdered = hasDistSchedule;
// No schedule modifier is set, so both modifier comparisons passed below are
// false.
6684 std::optional<omp::ScheduleModifier> scheduleMod;
6685 bool isSimd =
false;
6686 llvm::omp::WorksharingLoopType workshareLoopType =
6687 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6688 bool loopNeedsBarrier =
false;
// Translated dist_schedule chunk size; it is passed twice to
// applyWorkshareLoop below.
6689 llvm::Value *chunk = moduleTranslation.
lookupValue(
6690 distributeOp.getDistScheduleChunkSize());
6691 llvm::CanonicalLoopInfo *loopInfo =
6693 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6694 ompBuilder->applyWorkshareLoop(
6695 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6696 convertToScheduleKind(schedule), chunk, isSimd,
6697 scheduleMod == omp::ScheduleModifier::monotonic,
6698 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6699 workshareLoopType,
false, hasDistSchedule, chunk);
6702 return wsloopIP.takeError();
6705 distributeOp.getLoc(), privVarsInfo)))
6706 return llvm::make_error<PreviouslyReportedError>();
6708 return llvm::Error::success();
// Emit the distribute construct with the body callback and continue after
// it.
6712 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6714 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6715 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6716 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
6721 builder.restoreIP(*afterIP);
// Reduction finalization, only when reductions were set up above.
6723 if (doDistributeReduction) {
6726 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6727 privateReductionVariables, isByRef,
// NOTE(review): llvm::cast<> asserts on a type mismatch instead of returning
// null, so for a non-null `op` this condition cannot be true. dyn_cast<> (or
// isa<>) may have been intended — confirm.
6739 if (!cast<mlir::ModuleOp>(op))
// Record the OpenMP device version as a module flag; `Max` merge behaviour
// is requested.
6744 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
6745 attribute.getOpenmpDeviceVersion());
6747 if (attribute.getNoGpuLib())
// Each remaining setting of the attribute is emitted as a named global flag.
6750 ompBuilder->createGlobalFlag(
6751 attribute.getDebugKind() ,
6752 "__omp_rtl_debug_kind");
6753 ompBuilder->createGlobalFlag(
6755 .getAssumeTeamsOversubscription()
6757 "__omp_rtl_assume_teams_oversubscription");
6758 ompBuilder->createGlobalFlag(
6760 .getAssumeThreadsOversubscription()
6762 "__omp_rtl_assume_threads_oversubscription");
6763 ompBuilder->createGlobalFlag(
6764 attribute.getAssumeNoThreadState() ,
6765 "__omp_rtl_assume_no_thread_state");
6766 ompBuilder->createGlobalFlag(
6768 .getAssumeNoNestedParallelism()
6770 "__omp_rtl_assume_no_nested_parallelism");
6775 omp::TargetOp targetOp,
6776 llvm::OpenMPIRBuilder &ompBuilder,
6777 llvm::vfs::FileSystem &vfs,
6778 llvm::StringRef parentName =
"") {
// Computes the unique entry info for `targetOp` from its source location.
// The op's location must contain a FileLineColLoc; this is checked by
// assertion only.
6779 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
6780 assert(fileLoc &&
"No file found from location");
// Callback that supplies the (file name, line) pair to the OpenMP IR builder.
6782 auto fileInfoCallBack = [&fileLoc]() {
6783 return std::pair<std::string, uint64_t>(
6784 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
6788 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
6794 llvm::IRBuilderBase &builder, llvm::Function *
func) {
6796 "function only supported for target device codegen");
// Restores the builder's insertion point when this scope exits.
6797 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6798 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
// The rewrite below concerns declare-target entries.
6811 if (!mapData.IsDeclareTarget[i])
// Constant users of the original value inside `func` are first turned into
// instructions, so they can be rewritten like any other user.
6819 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6820 convertUsersOfConstantsToInstructions(constant,
func,
false);
// Snapshot the users into userVec before rewriting; the loop below modifies
// the original value's uses.
6827 for (llvm::User *user : mapData.OriginalValue[i]->users())
6828 userVec.push_back(user);
6830 for (llvm::User *user : userVec) {
// Only instructions inside `func` are candidates for rewriting.
6831 auto *insn = dyn_cast<llvm::Instruction>(user);
6832 if (!insn || insn->getFunction() !=
func)
6834 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
// Default replacement: the entry's base pointer.
6835 llvm::Value *substitute = mapData.BasePointers[i];
6837 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
// On this path the replacement is instead a load through the base pointer,
// placed immediately before the using instruction and carrying that
// instruction's debug location.
6841 ->Config.hasRequiresUnifiedSharedMemory())) {
6842 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6843 substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(),
6844 mapData.BasePointers[i]);
6845 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6847 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6892 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
6893 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
6894 llvm::OpenMPIRBuilder &ompBuilder,
6896 llvm::IRBuilderBase::InsertPoint allocaIP,
6897 llvm::IRBuilderBase::InsertPoint codeGenIP,
// Device-side code generation only.
6899 assert(ompBuilder.Config.isTargetDevice() &&
6900 "function only supported for target device codegen");
6901 builder.restoreIP(allocaIP);
// Defaults used when no map entry matches `input`: by-reference capture and
// no known alignment.
6903 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6905 ompBuilder.M.getContext());
6906 unsigned alignmentValue = 0;
6909 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
// Find the map entry whose original value is `input`, and take its capture
// kind.
6912 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6913 if (mapData.OriginalValue[i] == input) {
6914 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6915 capture = mapOp.getMapCaptureType();
6918 mapOp.getVarPtrType(), ompBuilder.M.getDataLayout());
// Locate the entry block argument paired with this map op's result.
6922 for (
auto &[val, arg] : blockArgsPairs) {
6923 if (mapOp.getResult() == val) {
6928 assert(mlirArg &&
"expected to find entry block argument for map clause");
// Address spaces used for stack allocations and for the program by default.
6933 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6934 unsigned int defaultAS =
6935 ompBuilder.M.getDataLayout().getProgramAddressSpace();
// `v` is the storage slot that receives the kernel argument.
6938 llvm::Value *v =
nullptr;
// Shared-allocation path: allocate at the start of the code-gen block and
// emit a matching free at every deallocation point.
6946 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
6947 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
6951 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6952 for (
auto deallocIP : deallocIPs) {
6953 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
6954 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
// Stack path: alloca in the alloca address space, cast to the default
// address space when the two differ and the argument is a pointer.
6958 v = builder.CreateAlloca(arg.getType(), allocaAS);
6960 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6961 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
// Spill the incoming argument into the slot, then continue at the code-gen
// point.
6964 builder.CreateStore(&arg, v);
6966 builder.restoreIP(codeGenIP);
6969 case omp::VariableCaptureKind::ByCopy: {
// ByRef: the stored value is loaded back with the preferred alignment of the
// slot's type.
6973 case omp::VariableCaptureKind::ByRef: {
6974 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6976 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
// When an alignment is known for the mapped variable, it is attached to the
// load as !align metadata.
6991 if (v->getType()->isPointerTy() && alignmentValue) {
6992 llvm::MDBuilder MDB(builder.getContext());
6993 loadInst->setMetadata(
6994 llvm::LLVMContext::MD_align,
6995 llvm::MDNode::get(builder.getContext(),
6996 MDB.createConstant(llvm::ConstantInt::get(
6997 llvm::Type::getInt64Ty(builder.getContext()),
// `This` and `VLAType` captures are not supported; this is enforced by
// assertion only.
7004 case omp::VariableCaptureKind::This:
7005 case omp::VariableCaptureKind::VLAType:
7008 assert(
false &&
"Currently unsupported capture kind");
7012 return builder.saveIP();
// For every (host_eval operand, entry block argument) pair of the target op,
// find which clause consumes the block argument and record the host-side
// operand in the matching output.
7029 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
7030 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
7031 blockArgIface.getHostEvalBlockArgs())) {
7032 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
// teams: the argument may feed the num_teams lower bound, one of the
// num_teams upper values, or the first thread_limit value.
7036 .Case([&](omp::TeamsOp teamsOp) {
7037 if (teamsOp.getNumTeamsLower() == blockArg)
7038 numTeamsLower = hostEvalVar;
7039 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
7041 numTeamsUpper = hostEvalVar;
7042 else if (!teamsOp.getThreadLimitVars().empty() &&
7043 teamsOp.getThreadLimit(0) == blockArg)
7044 threadLimit = hostEvalVar;
7046 llvm_unreachable(
"unsupported host_eval use");
// parallel: only the first num_threads value is recognised.
7048 .Case([&](omp::ParallelOp parallelOp) {
7049 if (!parallelOp.getNumThreadsVars().empty() &&
7050 parallelOp.getNumThreads(0) == blockArg)
7051 numThreads = hostEvalVar;
7053 llvm_unreachable(
"unsupported host_eval use");
// loop nest: the argument may appear among the lower bounds, upper bounds or
// steps. A matching position i is recorded in the corresponding output
// vector.
7055 .Case([&](omp::LoopNestOp loopOp) {
7056 auto processBounds =
7060 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
7061 if (lb == blockArg) {
7064 (*outBounds)[i] = hostEvalVar;
// All three bound lists are always processed; `found` accumulates whether
// any of them matched.
7070 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
7071 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
7073 found = processBounds(loopOp.getLoopSteps(), steps) || found;
7075 assert(found &&
"unsupported host_eval use");
// Any other user of a host_eval block argument is a programming error.
7077 .DefaultUnreachable(
"unsupported host_eval use");
// Generic helper parameterised on the operation type `OpTy`. It first tries
// `op` itself; with `immediateParent` set it then only considers op's direct
// parent, which may be absent.
7089template <
typename OpTy>
7094 if (OpTy casted = dyn_cast<OpTy>(op))
7097 if (immediateParent)
7098 return dyn_cast_if_present<OpTy>(op->
getParentOp());
// Early "no value" result.
7107 return std::nullopt;
// A constant whose value is an IntegerAttr yields that integer.
7110 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7111 return constAttr.getInt();
// Fallback: no constant integer could be extracted.
7113 return std::nullopt;
7118 uint64_t sizeInBytes = sizeInBits / 8;
// Generic helper over an op type that exposes getNumReductionVars(). It only
// does work when the op has at least one reduction variable.
7122template <
typename OpTy>
7124 if (op.getNumReductionVars() > 0) {
// Collect one member type per reduction declaration: the by-ref element type
// when the declaration has one, otherwise the declaration's own type.
7129 members.reserve(reductions.size());
7130 for (omp::DeclareReductionOp &red : reductions) {
7134 if (red.getByrefElementType())
7135 members.push_back(*red.getByrefElementType());
7137 members.push_back(red.getType());
// The member types are combined into a literal LLVM struct type.
7140 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7156 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7157 bool isTargetDevice,
bool isGPU) {
// Fills `attrs` with the kernel's default launch attributes.
7160 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
// Host compilation path.
7161 if (!isTargetDevice) {
// Clause values read directly from the ops. Only the first num_teams upper
// value, thread_limit value and num_threads value are used.
7169 numTeamsLower = teamsOp.getNumTeamsLower();
7171 if (!teamsOp.getNumTeamsUpperVars().empty())
7172 numTeamsUpper = teamsOp.getNumTeams(0);
7173 if (!teamsOp.getThreadLimitVars().empty())
7174 threadLimit = teamsOp.getThreadLimit(0);
7178 if (!parallelOp.getNumThreadsVars().empty())
7179 numThreads = parallelOp.getNumThreads(0);
// Team bounds default to min 1 / max -1. The branches below overwrite both
// with the same value: a constant clause value, 0, 1 or -1.
7185 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7189 if (numTeamsUpper) {
7191 minTeamsVal = maxTeamsVal = *val;
7193 minTeamsVal = maxTeamsVal = 0;
7199 minTeamsVal = maxTeamsVal = 1;
7201 minTeamsVal = maxTeamsVal = -1;
// Helper that derives an int32 result from a clause value.
7206 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
// thread_limit from the target op (first value only) and from teams; both
// start at -1.
7220 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7221 if (!targetOp.getThreadLimitVars().empty())
7222 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7223 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
// num_threads from parallel; starts at -1.
7226 int32_t maxThreadsVal = -1;
7228 setMaxValueFromClause(numThreads, maxThreadsVal);
// Combine the three limits. At each step the candidate replaces the running
// value when the running value is negative, or when the candidate is
// non-negative and smaller.
7236 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7237 if (combinedMaxThreadsVal < 0 ||
7238 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7239 combinedMaxThreadsVal = teamsThreadLimitVal;
7241 if (combinedMaxThreadsVal < 0 ||
7242 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7243 combinedMaxThreadsVal = maxThreadsVal;
// Starts at 0; the GPU-with-captured-op case is handled separately below.
7245 int32_t reductionDataSize = 0;
7246 if (isGPU && capturedOp) {
// Translate the op's execution mode into the OpenMP runtime's exec flags.
7252 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7254 case omp::TargetExecMode::bare:
7255 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7257 case omp::TargetExecMode::generic:
7258 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7260 case omp::TargetExecMode::spmd:
7261 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7263 case omp::TargetExecMode::no_loop:
7264 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
// Publish the computed values. MinThreads is always 1; only the first
// MaxTeams/MaxThreads slot is written.
7267 attrs.MinTeams = minTeamsVal;
7268 attrs.MaxTeams.front() = maxTeamsVal;
7269 attrs.MinThreads = 1;
7270 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7271 attrs.ReductionDataSize = reductionDataSize;
// A fixed reduction buffer length of 1024 is used whenever reduction data is
// present.
7274 if (attrs.ReductionDataSize != 0)
7275 attrs.ReductionBufferLength = 1024;
7287 omp::TargetOp targetOp,
Operation *capturedOp,
7288 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
// Fills `attrs` with launch values that are computed at run time.
//
// Zero loops when there is no loop op.
7290 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7292 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7296 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// thread_limit on the target op itself (first value only).
7299 if (!targetOp.getThreadLimitVars().empty()) {
7300 Value targetThreadLimit = targetOp.getThreadLimit(0);
7301 attrs.TargetThreadLimit.front() =
// Team bounds and the teams thread limit are converted to i32 by sign
// extension or truncation.
7309 attrs.MinTeams = builder.CreateSExtOrTrunc(
7310 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
7313 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7314 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
7316 if (teamsThreadLimit)
7317 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7318 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
// MaxThreads takes the translated value as-is, with no cast.
7321 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
// getKernelExecFlags is called here for its out-parameter, which is tested
// next.
7323 bool hostEvalTripCount;
7324 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7325 if (hostEvalTripCount) {
7327 attrs.LoopTripCount =
nullptr;
// The total trip count is the product of the per-loop trip counts.
7332 for (
auto [loopLower, loopUpper, loopStep] :
7333 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7334 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
7335 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
7336 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
// If any bound or step has no translated value, the trip count is reset to
// null.
7338 if (!lowerBound || !upperBound || !step) {
7339 attrs.LoopTripCount =
nullptr;
7343 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7344 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7345 loc, lowerBound, upperBound, step,
true,
7346 loopOp.getLoopInclusive());
// The first loop seeds the product; later loops multiply into it.
7348 if (!attrs.LoopTripCount) {
7349 attrs.LoopTripCount = tripCount;
7354 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
// Device id: either the "undefined" constant or the translated device
// operand, normalised to i64.
7359 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7361 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
7363 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
// Converts an omp::FallbackModifier into the matching
// llvm::omp::OMPDynGroupprivateFallbackType value.
// NOTE(review): the line carrying the function name and parameter list is not
// visible in this excerpt.
7367static llvm::omp::OMPDynGroupprivateFallbackType
// A null/absent `fallbackAttr` is treated as default_mem.
7369 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7370 : omp::FallbackModifier::default_mem;
7372 case omp::FallbackModifier::abort:
7373 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7374 case omp::FallbackModifier::null:
7375 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7376 case omp::FallbackModifier::default_mem:
7377 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
// Reached only if `fb` holds a value not covered by the cases above.
7380 llvm_unreachable(
"unexpected dyn_groupprivate fallback type");
// Translation of an omp.target operation (`opInst` is cast to omp::TargetOp).
// NOTE(review): excerpt - the function header and many interior lines are not
// visible here; comments describe only the visible statements.
7386 auto targetOp = cast<omp::TargetOp>(opInst);
// Remember the debug location at entry; it is reused for the outlined
// function's subprogram further down.
7391 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7400 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7401 assert(parentBB &&
"No insert block is set for the builder");
7402 llvm::Function *parentLLVMFn = parentBB->getParent();
7403 assert(parentLLVMFn &&
"Parent Function must be valid");
// When the parent function has a DISubprogram, rebuild the current debug
// location with that subprogram as scope, keeping line, column and inlinedAt.
7404 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7405 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7406 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7407 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7410 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7411 bool isGPU = ompBuilder->Config.isGPU();
7414 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7415 auto &targetRegion = targetOp.getRegion();
// Filled in by bodyCB once the outlined function exists.
7432 llvm::Function *llvmOutlinedFn =
nullptr;
7433 TargetDirectiveEnumTy targetDirective =
7434 getTargetDirectiveEnumTyFromOp(&opInst);
// An offload entry is emitted when compiling for the device or when at least
// one target triple is configured.
7438 bool isOffloadEntry =
7439 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
// Collect private variables that are backed by a map entry so they can later
// be resolved to the corresponding map block argument.
7446 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7448 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7449 std::optional<DenseI64ArrayAttr> privateMapIndices =
7450 targetOp.getPrivateMapsAttr();
7452 for (
auto [privVarIdx, privVarSymPair] :
7453 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7454 auto privVar = std::get<0>(privVarSymPair);
7455 auto privSym = std::get<1>(privVarSymPair);
7457 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7458 omp::PrivateClauseOp privatizer =
// Privatizers that do not need a map are not recorded (the statement taken
// in that case is on a line not shown here).
7461 if (!privatizer.needsMap())
7465 targetOp.getMappedValueForPrivateVar(privVarIdx);
7466 assert(mappedValue &&
"Expected to find mapped value for a privatized "
7467 "variable that needs mapping");
7472 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
// varType is only referenced from an assertion, hence [[maybe_unused]].
7473 [[maybe_unused]]
Type varType = mapInfoOp.getVarPtrType();
7477 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7479 varType == privVar.getType() &&
7480 "Type of private var doesn't match the type of the mapped value");
// The recorded block argument is the map argument at
// getMapBlockArgsStart() + privateMapIndices[privVarIdx].
7484 mappedPrivateVars.insert(
7486 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7487 (*privateMapIndices)[privVarIdx])});
7491 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// bodyCB: emits the body of the target region into the outlined function and
// returns the insertion point after it, or an error.
7492 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7494 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// The guard restores the builder's insertion point when the callback exits.
7495 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7496 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7499 llvm::Function *llvmParentFn =
7501 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7502 assert(llvmParentFn && llvmOutlinedFn &&
7503 "Both parent and outlined functions must exist at this point");
// Give the outlined function the subprogram of the saved debug location's
// scope, but only when the parent function carries debug info.
7505 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7506 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
// Propagate the "target-cpu" and "target-features" string attributes from
// the parent function to the outlined function.
7508 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
7509 attr.isStringAttribute())
7510 llvmOutlinedFn->addFnAttr(attr);
7512 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
7513 attr.isStringAttribute())
7514 llvmOutlinedFn->addFnAttr(attr);
// Bind each map block argument to the LLVM value of its map op's var pointer.
7516 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7517 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7518 llvm::Value *mapOpValue =
7519 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7520 moduleTranslation.
mapValue(arg, mapOpValue);
// Same binding for the `hda` block arguments and their map ops.
7522 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7523 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7524 llvm::Value *mapOpValue =
7525 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7526 moduleTranslation.
mapValue(arg, mapOpValue);
// Failures in the steps below surface as PreviouslyReportedError.
7535 privateVarsInfo, allocaIP, &mappedPrivateVars);
7538 return llvm::make_error<PreviouslyReportedError>();
7540 builder.restoreIP(codeGenIP);
7542 &mappedPrivateVars),
7545 return llvm::make_error<PreviouslyReportedError>();
7548 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7550 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7551 return llvm::make_error<PreviouslyReportedError>();
7554 moduleTranslation, allocaIP, deallocBlocks);
7556 targetRegion,
"omp.target", builder, moduleTranslation);
7559 return llvm::make_error<PreviouslyReportedError>();
// Continue emitting right before the terminator of the region's exit block.
7561 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7564 targetOp.getLoc(), privateVarsInfo)))
7565 return llvm::make_error<PreviouslyReportedError>();
7567 return builder.saveIP();
7570 StringRef parentName = parentFn.getName();
7572 llvm::TargetRegionEntryInfo entryInfo;
7578 MapInfoData mapData;
7583 MapInfosTy combinedInfos;
// Map-info callback: generates the combined map information at `codeGenIP`
// and returns a reference to `combinedInfos`.
7585 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7586 builder.restoreIP(codeGenIP);
7587 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
// Appends one extra entry: null base pointer and pointer, no device pointer,
// size 0, flags TARGET_PARAM | LITERAL, and no mapper.
// NOTE(review): the condition guarding this entry is on lines not shown here.
7592 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7593 combinedInfos.BasePointers.push_back(nullPtr);
7594 combinedInfos.Pointers.push_back(nullPtr);
7595 combinedInfos.DevicePointers.push_back(
7596 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7597 combinedInfos.Sizes.push_back(builder.getInt64(0));
7598 combinedInfos.Types.push_back(
7599 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7600 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
// A name is appended only when names are already being tracked.
7601 if (!combinedInfos.Names.empty())
7602 combinedInfos.Names.push_back(nullPtr);
7603 combinedInfos.Mappers.push_back(
nullptr);
7605 return combinedInfos;
// Argument accessor callback: produces in `retVal` the value used inside the
// outlined function for kernel argument `arg`.
7608 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7609 llvm::Value *&retVal, InsertPointTy allocaIP,
7610 InsertPointTy codeGenIP,
7612 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7613 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7614 builder.SetCurrentDebugLocation(llvm::DebugLoc());
// On the host the argument itself is used as the accessor value.
7620 if (!isTargetDevice) {
7621 retVal = cast<llvm::Value>(&arg);
7626 builder, *ompBuilder, moduleTranslation,
7627 allocaIP, codeGenIP, deallocIPs);
7630 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7631 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7632 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7634 isTargetDevice, isGPU);
// Runtime attributes are initialized only when compiling for the host.
7638 if (!isTargetDevice)
7640 targetCapturedOp, runtimeAttrs);
// Bind host-eval block arguments to their translated values; values that are
// not LLVM constants are passed to the kernel as inputs.
7648 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7649 llvm::Value *value = moduleTranslation.
lookupValue(var);
7650 moduleTranslation.
mapValue(arg, value);
7652 if (!llvm::isa<llvm::Constant>(value))
7653 kernelInput.push_back(value);
// Map entries become kernel inputs unless they are declare-target, members of
// another entry, or carry the OMP_MAP_ATTACH flag.
7656 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7665 bool isAttachMap = (mapData.Types[i] &
7666 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
7667 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
7668 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i] && !isAttachMap)
7669 kernelInput.push_back(mapData.OriginalValue[i]);
7673 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
// Dependency information is built from the op's depend clauses.
7676 llvm::OpenMPIRBuilder::DependenciesInfo dds;
7678 targetOp.getDependVars(), targetOp.getDependKinds(),
7679 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
7680 builder, moduleTranslation, dds)))
7683 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7685 llvm::OpenMPIRBuilder::TargetDataInfo info(
// Custom mapper callback: entries without a mapper are handled on a line not
// shown here; otherwise `info.HasMapper` is set.
7689 auto customMapperCB =
7691 if (!combinedInfos.Mappers[i])
7693 info.HasMapper =
true;
7695 moduleTranslation, targetDirective);
// Optional `if` clause condition; stays null when the clause is absent.
7698 llvm::Value *ifCond =
nullptr;
7699 if (
Value targetIfCond = targetOp.getIfExpr())
7700 ifCond = moduleTranslation.
lookupValue(targetIfCond);
// Optional dyn_groupprivate size, cast to i32 when present.
7702 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
7703 llvm::Value *dynSizeVal =
nullptr;
7704 if (dynGroupPrivateSize) {
7705 dynSizeVal = moduleTranslation.
lookupValue(dynGroupPrivateSize);
7706 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
7710 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
// Hand everything to the OpenMPIRBuilder; the result is either the insertion
// point after the target construct or an error.
7713 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7715 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
7716 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
7717 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
7718 targetOp.getNowait(), dynSizeVal, fallbackType);
7723 builder.restoreIP(*afterIP);
// Releases the dependency array.
// NOTE(review): any guard around this call is on a line not shown here.
7726 builder.CreateFree(dds.DepArray);
// Handling of the declare-target attribute on functions and LLVM globals.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here; comments describe only the visible statements.
7739 llvm::OpenMPIRBuilder *ompBuilder,
// Function case.
7748 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7749 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
// The statement taken when the module is not a target-device module is on a
// line not shown here (presumably an early exit - TODO confirm).
7751 if (!offloadMod.getIsTargetDevice())
7754 omp::DeclareTargetDeviceType declareType =
7755 attribute.getDeviceType().getValue();
// Functions declared for the host only are erased from the LLVM module.
7757 if (declareType == omp::DeclareTargetDeviceType::host) {
7758 llvm::Function *llvmFunc =
7760 llvmFunc->dropAllReferences();
7761 llvmFunc->eraseFromParent();
// Reset the builder's insertion point and debug location after the erase.
7765 ompBuilder->Builder.ClearInsertionPoint();
7766 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
// Global variable case: only globals already present in the LLVM module by
// their symbol name are processed.
7772 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7773 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7774 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7776 bool isDeclaration = gOp.isDeclaration();
7777 bool isExternallyVisible =
7780 llvm::StringRef mangledName = gOp.getSymName();
7781 auto captureClause =
7787 std::vector<llvm::GlobalVariable *> generatedRefs;
// At most one triple is collected, taken from the target-triple attribute
// when it is present and a string attribute.
7789 std::vector<llvm::Triple> targetTriple;
7790 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7792 LLVM::LLVMDialect::getTargetTripleAttrName()));
7793 if (targetTripleAttr)
7794 targetTriple.emplace_back(targetTripleAttr.data());
// Callback yielding the (filename, line) pair of `loc`; the defaults are an
// empty filename and line 0.
7796 auto fileInfoCallBack = [&loc]() {
7797 std::string filename =
"";
7798 std::uint64_t lineNo = 0;
7801 filename = loc.getFilename().str();
7802 lineNo = loc.getLine();
7805 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7809 llvm::vfs::FileSystem &vfs = moduleTranslation.
getFileSystem();
// Register the global with the OpenMPIRBuilder's offload bookkeeping.
7811 ompBuilder->registerTargetGlobalVariable(
7812 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7813 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7814 mangledName, generatedRefs,
false, targetTriple,
7816 gVal->getType(), gVal);
// On the device, additionally request the declare-target variable address
// when the capture clause is not `to`, or when unified shared memory is
// required.
7818 if (ompBuilder->Config.isTargetDevice() &&
7819 (attribute.getCaptureClause().getValue() !=
7820 mlir::omp::DeclareTargetCaptureClause::to ||
7821 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7822 ompBuilder->getAddrOfDeclareTargetVar(
7823 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7824 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7825 mangledName, generatedRefs,
false, targetTriple,
7826 gVal->getType(),
nullptr,
// Translation interface of the OpenMP dialect, derived from
// LLVMTranslationDialectInterface.
// NOTE(review): excerpt - access specifiers, return types and the member
// declarations of this class are on lines not shown here.
7839class OpenMPDialectLLVMIRTranslationInterface
7840 :
public LLVMTranslationDialectInterface {
// Inherit the base-class constructors.
7842 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
// Translates a single operation using `builder` and `moduleTranslation`.
7847 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7848 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Applies a dialect attribute found on `op`; `instructions` are the LLVM
// instructions associated with it.
7853 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7854 NamedAttribute attribute,
7855 LLVM::ModuleTranslation &moduleTranslation)
const final;
// Records `ptr` as the allocation associated with the MLIR value `var`.
// NOTE(review): this is a const member that writes ompAllocatedPtrs, so that
// member is presumably declared mutable - confirm against its declaration.
7860 void registerAllocatedPtr(Value var, llvm::Value *ptr)
const {
7861 ompAllocatedPtrs[var] = ptr;
// Returns the pointer recorded for `var`, or nullptr when none was recorded.
7866 llvm::Value *lookupAllocatedPtr(Value var)
const {
7867 auto it = ompAllocatedPtrs.find(var);
7868 return it != ompAllocatedPtrs.end() ? it->second :
nullptr;
// Applies an OpenMP dialect attribute to the translation state. A StringSwitch
// selects a handler per attribute name; each handler receives the attribute
// value and checks its type with dyn_cast before acting.
// NOTE(review): excerpt - the switched-on expression, some case labels and
// the handlers' return statements are on lines not shown here.
7880LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7881 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7882 NamedAttribute attribute,
7883 LLVM::ModuleTranslation &moduleTranslation)
const {
7884 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
// BoolAttr: forwarded to the builder config's IsTargetDevice setting.
7886 .Case(
"omp.is_target_device",
7887 [&](Attribute attr) {
7888 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7889 llvm::OpenMPIRBuilderConfig &config =
7891 config.setIsTargetDevice(deviceAttr.getValue());
// BoolAttr: forwarded to the builder config's IsGPU setting.
7897 [&](Attribute attr) {
7898 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7899 llvm::OpenMPIRBuilderConfig &config =
7901 config.setIsGPU(gpuAttr.getValue());
// StringAttr: path from which offload info metadata is loaded through the
// translation's file system.
7906 .Case(
"omp.host_ir_filepath",
7907 [&](Attribute attr) {
7908 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7909 llvm::OpenMPIRBuilder *ompBuilder =
7911 ompBuilder->loadOffloadInfoMetadata(
7912 moduleTranslation.
getFileSystem(), filepathAttr.getValue());
// omp::FlagsAttr handler (its case label and body are not shown here).
7918 [&](Attribute attr) {
7919 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
// VersionAttr: recorded as the "openmp" module flag with Max merge behavior.
7923 .Case(
"omp.version",
7924 [&](Attribute attr) {
7925 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7926 llvm::OpenMPIRBuilder *ompBuilder =
7928 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
7929 versionAttr.getVersion());
// DeclareTargetAttr: delegated to a helper that takes the builder and the
// module translation.
7934 .Case(
"omp.declare_target",
7935 [&](Attribute attr) {
7936 if (
auto declareTargetAttr =
7937 dyn_cast<omp::DeclareTargetAttr>(attr)) {
7938 llvm::OpenMPIRBuilder *ompBuilder =
7941 ompBuilder, moduleTranslation);
// ClauseRequiresAttr: each `requires` flag is copied into the builder config.
7945 .Case(
"omp.requires",
7946 [&](Attribute attr) {
7947 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7948 using Requires = omp::ClauseRequires;
7949 Requires flags = requiresAttr.getValue();
7950 llvm::OpenMPIRBuilderConfig &config =
7952 config.setHasRequiresReverseOffload(
7953 bitEnumContainsAll(flags, Requires::reverse_offload));
7954 config.setHasRequiresUnifiedAddress(
7955 bitEnumContainsAll(flags, Requires::unified_address));
7956 config.setHasRequiresUnifiedSharedMemory(
7957 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7958 config.setHasRequiresDynamicAllocators(
7959 bitEnumContainsAll(flags, Requires::dynamic_allocators));
// ArrayAttr: replaces the configured target triples with the string elements
// of the array; elements that are not StringAttr are ignored.
7964 .Case(
"omp.target_triples",
7965 [&](Attribute attr) {
7966 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7967 llvm::OpenMPIRBuilderConfig &config =
7969 config.TargetTriples.clear();
7970 config.TargetTriples.reserve(triplesAttr.size());
7971 for (Attribute tripleAttr : triplesAttr) {
7972 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7973 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
// Fallback handler for attribute names without a dedicated case.
7981 .Default([](Attribute) {
// NOTE(review): fragment - the enclosing function header and the statement
// guarded by this condition are not visible in this excerpt.
// The condition holds when parentFn's operation implements
// DeclareTargetInterface, is declare-target, and its device type is not host.
7997 if (
auto declareTargetIface =
7998 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7999 parentFn.getOperation()))
8000 if (declareTargetIface.isDeclareTarget() &&
8001 declareTargetIface.getDeclareTargetDeviceType() !=
8002 mlir::omp::DeclareTargetDeviceType::host)
// Obtains the `omp_target_alloc` function from `llvmModule`, inserting a
// declaration when it does not exist yet.
// NOTE(review): excerpt - the function name line and the return statement are
// not visible here.
8012 llvm::Module *llvmModule) {
8013 llvm::Type *i64Ty = builder.getInt64Ty();
8014 llvm::Type *i32Ty = builder.getInt32Ty();
// The result is a pointer in address space 0.
8015 llvm::Type *returnType = builder.getPtrTy(0);
// Signature: ptr (i64, i32), not variadic.
8016 llvm::FunctionType *fnType =
8017 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
// cast<> asserts that the returned callee is an llvm::Function.
8018 llvm::Function *
func = cast<llvm::Function>(
8019 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
// Generic allocation-size computation for an op `op` of type T that exposes
// getMemElemTypeAttr(), getMemAlignment() and getMemArraySize().
// NOTE(review): excerpt - the function name line and parts of the final
// multiplication are not visible here.
8023template <
typename T>
8027 llvm::DataLayout dataLayout =
8029 llvm::Type *llvmHeapTy =
8030 moduleTranslation.
convertType(op.getMemElemTypeAttr().getValue());
// Element size: the type's store size rounded up to the op's alignment when
// one is given, otherwise to the type's ABI alignment.
8032 auto alignment = op.getMemAlignment();
8033 llvm::TypeSize typeSize = llvm::alignTo(
8034 dataLayout.getTypeStoreSize(llvmHeapTy),
8035 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
// getFixedValue() is used, so a fixed-size (non-scalable) type is expected.
8037 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
// The result is a product involving the array size cast to i64.
8038 return builder.CreateMul(
8040 builder.CreateIntCast(moduleTranslation.
lookupValue(op.getMemArraySize()),
8041 builder.getInt64Ty(),
// Allocation-size computation for omp::TargetAllocMemOp.
// NOTE(review): excerpt - the function name line, one operand of the
// multiplication and the return statement are not visible here.
8048 omp::TargetAllocMemOp op) {
8049 llvm::DataLayout dataLayout =
8051 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(op.getAllocatedType());
// Start from the alloc size of the allocated type as an i64 constant.
8052 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
8053 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
// Each type parameter, cast to i64, is multiplied into the running size.
8054 for (
auto typeParam : op.getTypeparams()) {
8055 allocSize = builder.CreateMul(
8057 builder.CreateIntCast(moduleTranslation.
lookupValue(typeParam),
8058 builder.getInt64Ty(),
// Translation of omp::TargetAllocMemOp into a runtime allocation call.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here.
8067 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
8072 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8076 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8078 llvm::Value *allocSize =
// Call the allocation function with (size, device number).
8081 llvm::CallInst *call =
8082 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
// The op's result is mapped to the returned pointer converted to an i64.
8083 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
8086 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
// NOTE(review): fragment - the function header and the computation of `size`
// are not visible in this excerpt.
8092 llvm::IRBuilderBase &builder,
// Maps the op's result to the value produced by createOMPAllocShared for
// `size`.
8096 moduleTranslation.
mapValue(allocMemOp.getResult(),
8097 ompBuilder->createOMPAllocShared(builder, size));
// Translation of omp::AllocateDirOp: for every listed variable an allocation
// call is created and registered with the translation interface.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here.
8104 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8105 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8108 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8109 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8110 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8112 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
// Normalize the allocator operand to a pointer: integers go through inttoptr,
// pointers through a bitcast/addrspacecast. Without an allocator operand a
// null pointer is used (the `else` line itself is not shown here).
8114 llvm::Value *allocator;
8115 if (
auto allocatorVar = allocateDirOp.getAllocator()) {
8116 allocator = moduleTranslation.
lookupValue(allocatorVar);
8117 if (allocator->getType()->isIntegerTy())
8118 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8119 else if (allocator->getType()->isPointerTy())
8120 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8121 allocator, builder.getPtrTy());
8123 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8126 for (
Value var : vars) {
8127 llvm::Type *llvmVarTy = moduleTranslation.
convertType(var.getType());
// For pointer-typed variables backed by an LLVM global, size and alignment
// are derived from the global's type instead of the pointer type.
8131 llvm::Type *typeToInspect = llvmVarTy;
8132 if (llvmVarTy->isPointerTy()) {
8135 if (
auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8136 typeToInspect = moduleTranslation.
convertType(gop.getGlobalType());
// Arrays: multiply the extents of all nested array dimensions, then scale by
// the innermost element size in bytes.
8141 if (
auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8142 llvm::Value *elementCount = builder.getInt64(1);
8143 llvm::Type *currentType = arrTy;
8144 while (
auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8145 elementCount = builder.CreateMul(
8146 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8147 currentType = nestedArrTy->getElementType();
// NOTE(review): the integer division by 8 truncates for element types whose
// bit size is not a multiple of 8 - confirm this is intended.
8149 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8151 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
// Non-array types use their store size.
8153 size = builder.getInt64(
8154 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
// Alignment: the op's `align` value when present, else the ABI alignment.
8157 uint64_t alignValue =
8158 alignAttr ? alignAttr.value()
8159 : dataLayout.getABITypeAlign(typeToInspect).value();
8160 llvm::Value *alignConst = builder.getInt64(alignValue);
// Round size up to a multiple of the alignment:
// ((size + align - 1) / align) * align. The add and mul pass `true` for the
// flag following the name argument (no-unsigned-wrap in IRBuilder).
8162 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1),
"",
true);
8163 size = builder.CreateUDiv(size, alignConst);
8164 size = builder.CreateMul(size, alignConst,
"",
true);
8166 std::string allocName =
8167 ompBuilder->createPlatformSpecificName({
".void.addr"});
// An explicit alignment selects the aligned allocation entry point.
8168 llvm::CallInst *allocCall;
8169 if (alignAttr.has_value()) {
8170 allocCall = ompBuilder->createOMPAlignedAlloc(
8171 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8175 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
// Remember the allocation so it can be looked up again for this variable.
8178 ompIface.registerAllocatedPtr(var, allocCall);
// Translation of omp::AllocateFreeOp: emits a free call for each variable's
// previously registered allocation.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here.
8187 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8188 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8190 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Normalize the allocator operand to a pointer: integers go through inttoptr,
// pointers through a bitcast/addrspacecast. Without an allocator operand a
// null pointer is used (the `else` line itself is not shown here).
8192 llvm::Value *allocator;
8193 if (
auto allocatorVar = freeOp.getAllocator()) {
8194 allocator = moduleTranslation.
lookupValue(allocatorVar);
8195 if (allocator->getType()->isIntegerTy())
8196 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8197 else if (allocator->getType()->isPointerTy())
8198 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8199 allocator, builder.getPtrTy());
8201 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
// Variables are visited in reverse order.
8206 for (
Value var : llvm::reverse(vars)) {
8207 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
// Error path for a variable without a recorded allocation (the guarding
// condition is on a line not shown here).
8209 return opInst.
emitError(
"omp.allocate_free: no allocation recorded");
8210 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator,
"");
// Obtains the `omp_target_free` function from `llvmModule`, inserting a
// declaration when it does not exist yet.
// NOTE(review): excerpt - the function name line and the return statement are
// not visible here.
8217 llvm::Module *llvmModule) {
8218 llvm::Type *ptrTy = builder.getPtrTy(0);
8219 llvm::Type *i32Ty = builder.getInt32Ty();
8220 llvm::Type *voidTy = builder.getVoidTy();
// Signature: void (ptr, i32), not variadic.
8221 llvm::FunctionType *fnType =
8222 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
// NOTE(review): dyn_cast yields nullptr when the callee is not an
// llvm::Function; whether the result is checked is not visible here.
8223 llvm::Function *
func = dyn_cast<llvm::Function>(
8224 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
// Translation of omp::TargetFreeMemOp into a runtime free call.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here.
8231 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8236 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8240 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8243 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
// The heap reference is converted to a pointer in address space 0 before it
// is passed to the free function together with the device number.
8245 llvm::Value *intToPtr =
8246 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
// NOTE(review): `ompTragetFreeFunc` looks like a misspelling of "Target";
// renaming it requires its declaration, which is not in this excerpt.
8247 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
// NOTE(review): fragment - the function header and the computation of `size`
// are not visible in this excerpt.
8253 llvm::IRBuilderBase &builder,
// Emits the shared-memory free for the translated heap reference of
// `freeMemOp`, passing `size` along.
8257 ompBuilder->createOMPFreeShared(
8258 builder, moduleTranslation.
lookupValue(freeMemOp.getHeapref()), size);
// Translation of omp::GroupprivateOp: chooses between a newly created global
// (device) and the original global (otherwise) as the resulting pointer.
// NOTE(review): excerpt - the function header and several interior lines are
// not visible here.
8267 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8272 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
// Decide from the device type (defaulting to `any`) whether this compilation
// should allocate: host -> only on the host, nohost -> only on the device,
// any -> always.
8276 bool shouldAllocate =
true;
8277 switch (groupprivateOp.getDeviceType().value_or(
8278 mlir::omp::DeclareTargetDeviceType::any)) {
8279 case mlir::omp::DeclareTargetDeviceType::host:
8280 shouldAllocate = !isTargetDevice;
8282 case mlir::omp::DeclareTargetDeviceType::nohost:
8283 shouldAllocate = isTargetDevice;
8285 case mlir::omp::DeclareTargetDeviceType::any:
8286 shouldAllocate =
true;
// The op's symbol must resolve to an LLVM global variable; otherwise an error
// is reported.
8292 &opInst, groupprivateOp.getSymNameAttr());
8295 <<
"expected symbol '" << groupprivateOp.getSymName()
8296 <<
"' to reference an LLVM global variable";
8298 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
8299 llvm::Type *varType = moduleTranslation.
convertType(global.getType());
8300 std::string varName = globalValue->getName().str();
8302 llvm::Value *resultPtr;
// Device path: pick the address space from the module's target triple; only
// AMDGCN and NVPTX are handled, any other target is an error.
8303 if (shouldAllocate && isTargetDevice) {
8304 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8305 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8306 unsigned sharedAddressSpace;
8307 if (targetTriple.isAMDGCN())
8308 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8309 else if (targetTriple.isNVPTX())
8310 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8312 return opInst.
emitError() <<
"groupprivate is not supported for target: "
8313 << targetTriple.str();
// New non-constant global with internal linkage, a poison initializer and no
// thread-local mode, named after the original global.
8314 llvm::GlobalVariable *sharedVar =
new llvm::GlobalVariable(
8315 *llvmModule, varType,
false,
8316 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8317 varName,
nullptr, llvm::GlobalValue::NotThreadLocal,
8320 resultPtr = sharedVar;
// Otherwise the original global is used; on the host a warning notes that the
// directive is currently ignored.
8322 if (shouldAllocate && !isTargetDevice)
8323 opInst.
emitWarning(
"groupprivate directive is currently ignored on the "
8324 "host, using original global");
8325 resultPtr = globalValue;
// Entry point for translating one OpenMP dialect operation: a TypeSwitch on
// the operation type dispatches to the per-operation conversion.
// NOTE(review): excerpt - the bodies of most cases and several other lines
// are not visible here.
8334LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8335 Operation *op, llvm::IRBuilderBase &builder,
8336 LLVM::ModuleTranslation &moduleTranslation)
const {
// When compiling for the device, an op failing this check is rejected as a
// host op (the rest of the condition is on lines not shown here).
8339 if (ompBuilder->Config.isTargetDevice() &&
8340 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8343 return op->
emitOpError() <<
"unsupported host op found in device";
// A loop wrapper is outermost when its parent is not itself a loop wrapper.
8351 bool isOutermostLoopWrapper =
8352 isa_and_present<omp::LoopWrapperInterface>(op) &&
8353 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
// Taskloop override: the context op counts as outermost, the wrapper op
// does not.
8362 if (isa<omp::TaskloopContextOp>(op))
8363 isOutermostLoopWrapper =
true;
8364 else if (isa<omp::TaskloopWrapperOp>(op))
8365 isOutermostLoopWrapper =
false;
// Outermost loop wrappers get a loop-info stack frame for their translation.
8367 if (isOutermostLoopWrapper)
8368 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
8371 llvm::TypeSwitch<Operation *, LogicalResult>(op)
// Barrier: on success the builder continues at the returned insertion point.
8372 .Case([&](omp::BarrierOp op) -> LogicalResult {
8376 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8377 ompBuilder->createBarrier(builder.saveIP(),
8378 llvm::omp::OMPD_barrier);
8380 if (res.succeeded()) {
8383 builder.restoreIP(*afterIP);
8387 .Case([&](omp::TaskyieldOp op) {
8391 ompBuilder->createTaskyield(builder.saveIP());
8394 .Case([&](omp::FlushOp op) {
8406 ompBuilder->createFlush(builder.saveIP());
8409 .Case([&](omp::ParallelOp op) {
8412 .Case([&](omp::MaskedOp) {
8415 .Case([&](omp::MasterOp) {
8418 .Case([&](omp::CriticalOp) {
8421 .Case([&](omp::OrderedRegionOp) {
8424 .Case([&](omp::OrderedOp) {
8427 .Case([&](omp::WsloopOp) {
8430 .Case([&](omp::SimdOp) {
8433 .Case([&](omp::AtomicReadOp) {
8436 .Case([&](omp::AtomicWriteOp) {
8439 .Case([&](omp::AtomicUpdateOp op) {
8442 .Case([&](omp::AtomicCaptureOp op) {
8445 .Case([&](omp::CancelOp op) {
8448 .Case([&](omp::CancellationPointOp op) {
8451 .Case([&](omp::SectionsOp) {
8454 .Case([&](omp::ScopeOp op) {
8457 .Case([&](omp::SingleOp op) {
8460 .Case([&](omp::TeamsOp op) {
8463 .Case([&](omp::TaskOp op) {
8466 .Case([&](omp::TaskloopWrapperOp op) {
8469 .Case([&](omp::TaskloopContextOp op) {
8472 .Case([&](omp::TaskgroupOp op) {
8475 .Case([&](omp::TaskwaitOp op) {
// Several op types share one handler here.
8478 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8479 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8480 omp::CriticalDeclareOp>([](
auto op) {
8493 .Case([&](omp::ThreadprivateOp) {
// The target data-movement ops share one handler.
8496 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8497 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
8500 .Case([&](omp::TargetOp) {
8503 .Case([&](omp::DistributeOp) {
8506 .Case([&](omp::LoopNestOp) {
8509 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8510 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
8516 .Case([&](omp::NewCliOp op) {
8521 .Case([&](omp::CanonicalLoopOp op) {
8524 .Case([&](omp::UnrollHeuristicOp op) {
8533 .Case([&](omp::TileOp op) {
8534 return applyTile(op, builder, moduleTranslation);
8536 .Case([&](omp::FuseOp op) {
8537 return applyFuse(op, builder, moduleTranslation);
8539 .Case([&](omp::TargetAllocMemOp) {
8542 .Case([&](omp::TargetFreeMemOp) {
8545 .Case([&](omp::AllocateDirOp) {
8548 .Case([&](omp::AllocateFreeOp) {
8552 .Case([&](omp::AllocSharedMemOp op) {
8555 .Case([&](omp::FreeSharedMemOp op) {
8558 .Case([&](omp::GroupprivateOp) {
// Any other operation is reported as not yet implemented.
8561 .Default([&](Operation *inst) {
8563 <<
"not yet implemented: " << inst->
getName();
// Counterpart of the stack push above (the guarded statement is on a line
// not shown here).
8566 if (isOutermostLoopWrapper)
// NOTE(review): fragment - the enclosing function headers are not visible in
// this excerpt.
// Adds the OpenMP dialect to `registry`.
8573 registry.
insert<omp::OpenMPDialect>();
// Attaches the OpenMP-to-LLVM-IR translation interface to the dialect.
8575 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static void mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag=llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam=true, int mapDataParentIdx=-1)
This function handles the insertion of a single item of map data from MapInfoData into the OMPIRBuild...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo, bool first=true)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.