26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
71 llvm_unreachable(
"unhandled schedule clause argument");
// Stack frame recording where allocations for the current OpenMP region
// should be inserted, together with the blocks in which matching
// deallocations may be emitted.
// NOTE(review): the base-class list (original lines 77-80) and the closing
// brace are missing from this extraction — presumably this derives from a
// ModuleTranslation stack-frame base; confirm against the full file.
76class OpenMPAllocStackFrame
// Captures the alloca insertion point and the candidate dealloc blocks.
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
// Insertion point at which new allocas for this region are created.
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
// Blocks where deallocations paired with those allocas can be placed.
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
// Stack frame holding the llvm::CanonicalLoopInfo of the OpenMP loop
// currently being translated; nullptr until a loop sets it.
// NOTE(review): the base-class list and closing brace are missing from this
// extraction (gap between original lines 92 and 96).
92class OpenMPLoopInfoStackFrame
96 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// Custom llvm::Error type signalling that a diagnostic for this failure has
// already been emitted, so callers should propagate failure without
// reporting it again.
// NOTE(review): method bodies are truncated in this extraction (gaps after
// original lines 118 and 124).
115class PreviouslyReportedError
116 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// log() intentionally prints nothing: the diagnostic was already emitted.
118 void log(raw_ostream &)
const override {
// This error has no std::error_code equivalent, per the message below.
122 std::error_code convertToErrorCode()
const override {
124 "PreviouslyReportedError doesn't support ECError conversion");
// Unique ID required by the llvm::ErrorInfo RTTI mechanism.
131char PreviouslyReportedError::ID = 0;
// Helper class that carries the state needed to lower the OpenMP `linear`
// clause: per-variable allocas, step values, converted types, and the basic
// blocks created for end-of-loop finalization.
// NOTE(review): this extraction has interior gaps (original line numbers are
// fused into each line and several lines, access specifiers, and closing
// braces are missing); the code below is documented as-is.
142class LinearClauseProcessor {
// Allocas capturing each linear variable's value on loop entry.
145 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas holding the per-iteration (privatized) value of each variable.
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original (shared) storage of each linear variable.
147 SmallVector<llvm::Value *> linearOrigVal;
// The step value associated with each linear variable.
148 SmallVector<llvm::Value *> linearSteps;
// LLVM type of each linear variable, converted from the MLIR type attr.
149 SmallVector<llvm::Type *> linearVarTypes;
// Blocks created by splitLinearFiniBB() for the finalization sequence.
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
// Records the converted LLVM type for one linear variable, taken from its
// TypeAttr on the MLIR op.
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.
convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
// Allocates the precondition and loop-body temporaries for the idx-th
// linear variable and remembers its original storage.
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar,
int idx) {
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
// Records the LLVM value of one linear step operand.
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
// In the loop preheader, loads each original linear variable and stores it
// into its precondition alloca (the loop-entry snapshot).
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
// In the loop body, computes start + iv * step for each linear variable
// and stores the result into its loop-body temporary.
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
// The OpenMP canonical loop IV is always an integer; anything else is a
// translation bug, not a user error.
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
// Integer case: widen/narrow iv and step to the variable's type, then
// start + iv * step.
205 if (linearVarType->isIntegerTy()) {
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 }
// Floating-point case: iv * step is computed in the integer domain and
// converted with SIToFP before the FAdd.
else if (linearVarType->isFloatingPointTy()) {
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
227 "Linear variable must be of integer or floating-point type");
// Splits the loop-exit block into finalization / exit / last-iteration-exit
// blocks used by finalizeLinearVar() below.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(),
"omp_loop.linear_finalization");
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emits the finalization logic: on the thread that executed the last
// iteration, copies the loop-body temporaries back into the original
// variables, then branches to the common exit (where a barrier is emitted;
// the createBarrier call site is partially missing from this extraction).
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
// Compare the last-iteration flag (an i32 loaded from `lastIter`) to 0.
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
// Write-back of the final linear values in the last-iteration path.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Replace the finalization block's unconditional terminator with the
// conditional branch on `isLast`.
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
272 builder.SetInsertPoint(linearExitBB->getTerminator());
274 builder.saveIP(), llvm::omp::OMPD_barrier);
// Unconditionally copies each loop-body temporary back to the original
// variable at the current insertion point.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Rewrites uses of the original linear variable inside blocks whose name
// contains BBName so they use the privatized loop-body temporary instead.
// Uses are snapshotted first because replaceUsesOfWith mutates the use
// list while we iterate.
289 void rewriteInPlace(llvm::IRBuilderBase &builder,
const std::string &BBName,
291 llvm::SmallVector<llvm::User *> users;
292 for (llvm::User *user : linearOrigVal[varIndex]->users())
293 users.push_back(user);
294 for (
auto *user : users) {
295 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
296 if (userInst->getParent()->getName().str().find(BBName) !=
298 user->replaceUsesOfWith(linearOrigVal[varIndex],
299 linearLoopBodyTemps[varIndex]);
310 SymbolRefAttr symbolName) {
311 omp::PrivateClauseOp privatizer =
314 assert(privatizer &&
"privatizer not found in the symbol table");
325 auto todo = [&op](StringRef clauseName) {
326 return op.
emitError() <<
"not yet implemented: Unhandled clause "
327 << clauseName <<
" in " << op.
getName()
331 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
332 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
333 result = todo(
"allocate");
335 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
337 result = todo(
"ompx_bare");
339 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
340 if (!op.getDependVars().empty() || op.getDependKinds())
343 auto checkHint = [](
auto op, LogicalResult &) {
347 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
348 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
349 op.getInReductionSyms())
350 result = todo(
"in_reduction");
352 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
356 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
357 if (op.getOrder() || op.getOrderMod())
360 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
361 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
362 result = todo(
"privatization");
364 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
365 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
366 if (!op.getReductionVars().empty() || op.getReductionByref() ||
367 op.getReductionSyms())
368 result = todo(
"reduction");
369 if (op.getReductionMod() &&
370 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
371 result = todo(
"reduction with modifier");
373 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
374 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
375 op.getTaskReductionSyms())
376 result = todo(
"task_reduction");
378 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
379 if (op.hasNumTeamsMultiDim())
380 result = todo(
"num_teams with multi-dimensional values");
382 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
383 if (op.hasNumThreadsMultiDim())
384 result = todo(
"num_threads with multi-dimensional values");
387 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
388 if (op.hasThreadLimitMultiDim())
389 result = todo(
"thread_limit with multi-dimensional values");
392 auto checkDynGroupprivate = [&todo](
auto op, LogicalResult &
result) {
393 if (op.getDynGroupprivateSize())
394 result = todo(
"dyn_groupprivate");
399 .Case([&](omp::DistributeOp op) {
400 checkAllocate(op,
result);
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op,
result);
406 checkReduction(op,
result);
408 .Case([&](omp::ScopeOp op) {
409 checkAllocate(op,
result);
410 checkReduction(op,
result);
412 .Case([&](omp::SingleOp op) {
413 checkAllocate(op,
result);
416 .Case([&](omp::TeamsOp op) {
417 checkAllocate(op,
result);
419 checkNumTeams(op,
result);
420 checkThreadLimit(op,
result);
421 checkDynGroupprivate(op,
result);
423 .Case([&](omp::TaskOp op) {
424 checkAllocate(op,
result);
425 checkInReduction(op,
result);
427 .Case([&](omp::TaskgroupOp op) {
428 checkAllocate(op,
result);
429 checkTaskReduction(op,
result);
431 .Case([&](omp::TaskwaitOp op) {
435 .Case([&](omp::TaskloopContextOp op) {
436 checkAllocate(op,
result);
437 checkInReduction(op,
result);
438 checkReduction(op,
result);
440 .Case([&](omp::WsloopOp op) {
441 checkAllocate(op,
result);
443 checkReduction(op,
result);
445 .Case([&](omp::ParallelOp op) {
446 checkAllocate(op,
result);
447 checkReduction(op,
result);
448 checkNumThreads(op,
result);
450 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
451 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
452 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
453 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
454 [&](
auto op) { checkDepend(op,
result); })
455 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
456 .Case([&](omp::TargetOp op) {
457 checkAllocate(op,
result);
459 checkInReduction(op,
result);
460 checkThreadLimit(op,
result);
472 llvm::handleAllErrors(
474 [&](
const PreviouslyReportedError &) {
result = failure(); },
475 [&](
const llvm::ErrorInfoBase &err) {
498 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
501 [&](OpenMPAllocStackFrame &frame) {
502 allocInsertPoint = frame.allocInsertPoint;
503 deallocInsertPoints = frame.deallocBlocks;
511 allocInsertPoint.getBlock()->getParent() ==
512 builder.GetInsertBlock()->getParent()) {
514 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
515 deallocInsertPoints.end());
516 return allocInsertPoint;
526 if (builder.GetInsertBlock() ==
527 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
528 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
529 "Assuming end of basic block");
530 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
531 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
532 builder.GetInsertBlock()->getNextNode());
533 builder.CreateBr(entryBB);
534 builder.SetInsertPoint(entryBB);
540 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
544 llvm::Instruction *terminator = block.getTerminatorOrNull();
545 if (isa_and_present<llvm::ReturnInst>(terminator))
546 deallocBlocks->emplace_back(&block);
550 llvm::BasicBlock &funcEntryBlock =
551 builder.GetInsertBlock()->getParent()->getEntryBlock();
552 return llvm::OpenMPIRBuilder::InsertPointTy(
553 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
559static llvm::CanonicalLoopInfo *
561 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
562 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
563 [&](OpenMPLoopInfoStackFrame &frame) {
564 loopInfo = frame.loopInfo;
576 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
579 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
581 llvm::BasicBlock *continuationBlock =
582 splitBB(builder,
true,
"omp.region.cont");
583 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
585 llvm::LLVMContext &llvmContext = builder.getContext();
586 for (
Block &bb : region) {
587 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
588 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
589 builder.GetInsertBlock()->getNextNode());
590 moduleTranslation.
mapBlock(&bb, llvmBB);
593 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
600 unsigned numYields = 0;
602 if (!isLoopWrapper) {
603 bool operandsProcessed =
false;
605 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
606 if (!operandsProcessed) {
607 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
608 continuationBlockPHITypes.push_back(
609 moduleTranslation.
convertType(yield->getOperand(i).getType()));
611 operandsProcessed =
true;
613 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
614 "mismatching number of values yielded from the region");
615 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
616 llvm::Type *operandType =
617 moduleTranslation.
convertType(yield->getOperand(i).getType());
619 assert(continuationBlockPHITypes[i] == operandType &&
620 "values of mismatching types yielded from the region");
630 if (!continuationBlockPHITypes.empty())
632 continuationBlockPHIs &&
633 "expected continuation block PHIs if converted regions yield values");
634 if (continuationBlockPHIs) {
635 llvm::IRBuilderBase::InsertPointGuard guard(builder);
636 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
637 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
638 for (llvm::Type *ty : continuationBlockPHITypes)
639 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
645 for (
Block *bb : blocks) {
646 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
649 if (bb->isEntryBlock()) {
650 assert(sourceTerminator->getNumSuccessors() == 1 &&
651 "provided entry block has multiple successors");
652 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
653 "ContinuationBlock is not the successor of the entry block");
654 sourceTerminator->setSuccessor(0, llvmBB);
657 llvm::IRBuilderBase::InsertPointGuard guard(builder);
659 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
660 return llvm::make_error<PreviouslyReportedError>();
665 builder.CreateBr(continuationBlock);
676 Operation *terminator = bb->getTerminator();
677 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
678 builder.CreateBr(continuationBlock);
680 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
681 (*continuationBlockPHIs)[i]->addIncoming(
695 return continuationBlock;
701 case omp::ClauseProcBindKind::Close:
702 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
703 case omp::ClauseProcBindKind::Master:
704 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
705 case omp::ClauseProcBindKind::Primary:
706 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
707 case omp::ClauseProcBindKind::Spread:
708 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
710 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
717 auto maskedOp = cast<omp::MaskedOp>(opInst);
718 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
723 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
726 auto ®ion = maskedOp.getRegion();
727 builder.restoreIP(codeGenIP);
735 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
737 llvm::Value *filterVal =
nullptr;
738 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
739 filterVal = moduleTranslation.
lookupValue(filterVar);
741 llvm::LLVMContext &llvmContext = builder.getContext();
743 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
745 assert(filterVal !=
nullptr);
746 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
747 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
754 builder.restoreIP(*afterIP);
762 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
763 auto masterOp = cast<omp::MasterOp>(opInst);
768 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
771 auto ®ion = masterOp.getRegion();
772 builder.restoreIP(codeGenIP);
780 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
782 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
783 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
790 builder.restoreIP(*afterIP);
798 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
799 auto criticalOp = cast<omp::CriticalOp>(opInst);
804 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
807 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
808 builder.restoreIP(codeGenIP);
816 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
818 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
819 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
820 llvm::Constant *hint =
nullptr;
823 if (criticalOp.getNameAttr()) {
826 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
827 auto criticalDeclareOp =
831 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
832 static_cast<int>(criticalDeclareOp.getHint()));
834 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
836 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
841 builder.restoreIP(*afterIP);
848 template <
typename OP>
851 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
854 collectPrivatizationDecls<OP>(op);
869 void collectPrivatizationDecls(OP op) {
870 std::optional<ArrayAttr> attr = op.getPrivateSyms();
875 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
886 std::optional<ArrayAttr> attr = op.getReductionSyms();
890 reductions.reserve(reductions.size() + op.getNumReductionVars());
891 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
892 reductions.push_back(
904 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
913 llvm::Instruction *potentialTerminator =
914 builder.GetInsertBlock()->empty() ?
nullptr
915 : &builder.GetInsertBlock()->back();
917 if (potentialTerminator && potentialTerminator->isTerminator())
918 potentialTerminator->removeFromParent();
919 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
922 region.
front(),
true, builder)))
926 if (continuationBlockArgs)
928 *continuationBlockArgs,
935 if (potentialTerminator && potentialTerminator->isTerminator()) {
936 llvm::BasicBlock *block = builder.GetInsertBlock();
937 if (block->empty()) {
943 potentialTerminator->insertInto(block, block->begin());
945 potentialTerminator->insertAfter(&block->back());
959 if (continuationBlockArgs)
960 llvm::append_range(*continuationBlockArgs, phis);
961 builder.SetInsertPoint(*continuationBlock,
962 (*continuationBlock)->getFirstInsertionPt());
969using OwningReductionGen =
970 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
971 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
973using OwningAtomicReductionGen =
974 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
975 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
977using OwningDataPtrPtrReductionGen =
978 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
979 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
985static OwningReductionGen
991 OwningReductionGen gen =
992 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
993 llvm::Value *
lhs, llvm::Value *
rhs,
994 llvm::Value *&
result)
mutable
995 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
996 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
997 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
998 builder.restoreIP(insertPoint);
1001 "omp.reduction.nonatomic.body", builder,
1002 moduleTranslation, &phis)))
1003 return llvm::createStringError(
1004 "failed to inline `combiner` region of `omp.declare_reduction`");
1005 result = llvm::getSingleElement(phis);
1006 return builder.saveIP();
1015static OwningAtomicReductionGen
1017 llvm::IRBuilderBase &builder,
1019 if (decl.getAtomicReductionRegion().empty())
1020 return OwningAtomicReductionGen();
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1027 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1030 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1031 builder.restoreIP(insertPoint);
1034 "omp.reduction.atomic.body", builder,
1035 moduleTranslation, &phis)))
1036 return llvm::createStringError(
1037 "failed to inline `atomic` region of `omp.declare_reduction`");
1038 assert(phis.empty());
1039 return builder.saveIP();
1048static OwningDataPtrPtrReductionGen
1051 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1052 return OwningDataPtrPtrReductionGen();
1054 OwningDataPtrPtrReductionGen refDataPtrGen =
1055 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1056 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1057 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1058 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1059 builder.restoreIP(insertPoint);
1062 "omp.data_ptr_ptr.body", builder,
1063 moduleTranslation, &phis)))
1064 return llvm::createStringError(
1065 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1066 result = llvm::getSingleElement(phis);
1067 return builder.saveIP();
1070 return refDataPtrGen;
1077 auto orderedOp = cast<omp::OrderedOp>(opInst);
1082 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1083 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1084 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1086 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1088 size_t indexVecValues = 0;
1089 while (indexVecValues < vecValues.size()) {
1091 storeValues.reserve(numLoops);
1092 for (
unsigned i = 0; i < numLoops; i++) {
1093 storeValues.push_back(vecValues[indexVecValues]);
1096 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1098 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1099 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1100 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1110 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1111 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1116 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1119 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1120 builder.restoreIP(codeGenIP);
1128 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1130 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1131 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1133 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1138 builder.restoreIP(*afterIP);
// A (value, address) pair whose store must be deferred until after all
// reduction-variable allocations have been created.
// NOTE(review): the `llvm::Value *value;` member declaration (original line
// 1148) and the closing brace are missing from this extraction.
1144struct DeferredStore {
1145 DeferredStore(llvm::Value *value, llvm::Value *address)
1146 : value(value), address(address) {}
// Destination address for the deferred store.
1149 llvm::Value *address;
1156template <
typename T>
1159 llvm::IRBuilderBase &builder,
1161 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1167 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1168 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1174 deferredStores.reserve(op.getNumReductionVars());
1176 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1177 Region &allocRegion = reductionDecls[i].getAllocRegion();
1179 if (allocRegion.
empty())
1184 builder, moduleTranslation, &phis)))
1185 return op.emitError(
1186 "failed to inline `alloc` region of `omp.declare_reduction`");
1188 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1189 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1193 llvm::Type *ptrTy = builder.getPtrTy();
1197 if (useDeviceSharedMem) {
1198 var = ompBuilder->createOMPAllocShared(builder, varTy);
1200 var = builder.CreateAlloca(varTy);
1201 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1204 llvm::Value *castPhi =
1205 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1207 deferredStores.emplace_back(castPhi, var);
1209 privateReductionVariables[i] = var;
1210 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1211 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1213 assert(allocRegion.
empty() &&
1214 "allocaction is implicit for by-val reduction");
1216 llvm::Type *ptrTy = builder.getPtrTy();
1220 if (useDeviceSharedMem) {
1221 var = ompBuilder->createOMPAllocShared(builder, varTy);
1223 var = builder.CreateAlloca(varTy);
1224 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1227 moduleTranslation.
mapValue(reductionArgs[i], var);
1228 privateReductionVariables[i] = var;
1229 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1237template <
typename T>
1240 llvm::IRBuilderBase &builder,
1245 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1246 Region &initializerRegion = reduction.getInitializerRegion();
1249 mlir::Value mlirSource = loop.getReductionVars()[i];
1250 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1251 llvm::Value *origVal = llvmSource;
1253 if (!isa<LLVM::LLVMPointerType>(
1254 reduction.getInitializerMoldArg().getType()) &&
1255 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1258 reduction.getInitializerMoldArg().getType()),
1259 llvmSource,
"omp_orig");
1261 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1264 llvm::Value *allocation =
1265 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1266 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1272 llvm::BasicBlock *block =
nullptr) {
1273 if (block ==
nullptr)
1274 block = builder.GetInsertBlock();
1276 if (!block->hasTerminator())
1277 builder.SetInsertPoint(block);
1279 builder.SetInsertPoint(block->getTerminator());
1287template <
typename OP>
1290 llvm::IRBuilderBase &builder,
1292 llvm::BasicBlock *latestAllocaBlock,
1298 if (op.getNumReductionVars() == 0)
1304 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1305 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1306 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1307 builder.restoreIP(allocaIP);
1310 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1312 if (!reductionDecls[i].getAllocRegion().empty())
1320 if (useDeviceSharedMem)
1321 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1323 byRefVars[i] = builder.CreateAlloca(varTy);
1331 for (
auto [data, addr] : deferredStores)
1332 builder.CreateStore(data, addr);
1337 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1342 reductionVariableMap, i);
1350 "omp.reduction.neutral", builder,
1351 moduleTranslation, &phis)))
1354 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1355 "reduction neutral element declaration region");
1360 if (!reductionDecls[i].getAllocRegion().empty())
1369 builder.CreateStore(phis[0], byRefVars[i]);
1371 privateReductionVariables[i] = byRefVars[i];
1372 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1373 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1376 builder.CreateStore(phis[0], privateReductionVariables[i]);
1383 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1390template <
typename T>
1391static void collectReductionInfo(
1392 T loop, llvm::IRBuilderBase &builder,
1401 unsigned numReductions = loop.getNumReductionVars();
1403 for (
unsigned i = 0; i < numReductions; ++i) {
1406 owningAtomicReductionGens.push_back(
1409 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1413 reductionInfos.reserve(numReductions);
1414 for (
unsigned i = 0; i < numReductions; ++i) {
1415 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1416 if (owningAtomicReductionGens[i])
1417 atomicGen = owningAtomicReductionGens[i];
1418 llvm::Value *variable =
1419 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1422 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1423 allocatedType = alloca.getElemType();
1430 reductionInfos.push_back(
1432 privateReductionVariables[i],
1433 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1437 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1438 reductionDecls[i].getByrefElementType()
1440 *reductionDecls[i].getByrefElementType())
1450 llvm::IRBuilderBase &builder, StringRef regionName,
1451 bool shouldLoadCleanupRegionArg =
true) {
1452 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1453 if (cleanupRegion->empty())
1459 llvm::Instruction *potentialTerminator =
1460 builder.GetInsertBlock()->empty() ?
nullptr
1461 : &builder.GetInsertBlock()->back();
1462 if (potentialTerminator && potentialTerminator->isTerminator())
1463 builder.SetInsertPoint(potentialTerminator);
1464 llvm::Value *privateVarValue =
1465 shouldLoadCleanupRegionArg
1466 ? builder.CreateLoad(
1468 privateVariables[i])
1469 : privateVariables[i];
1474 moduleTranslation)))
1487 OP op, llvm::IRBuilderBase &builder,
1489 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1492 bool isNowait =
false,
bool isTeamsReduction =
false) {
1494 if (op.getNumReductionVars() == 0)
1506 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1508 owningReductionGenRefDataPtrGens,
1509 privateReductionVariables, reductionInfos, isByRef);
1514 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1515 builder.SetInsertPoint(tempTerminator);
1516 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1517 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1518 isByRef, isNowait, isTeamsReduction);
1523 if (!contInsertPoint->getBlock())
1524 return op->emitOpError() <<
"failed to convert reductions";
1526 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1527 if (!isTeamsReduction) {
1528 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1529 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1533 afterIP = *barrierIP;
1536 tempTerminator->eraseFromParent();
1537 builder.restoreIP(afterIP);
1541 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1542 [](omp::DeclareReductionOp reductionDecl) {
1543 return &reductionDecl.getCleanupRegion();
1546 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1547 "omp.reduction.cleanup");
1550 if (useDeviceSharedMem) {
1551 for (
auto [var, reductionDecl] :
1552 llvm::zip_equal(privateReductionVariables, reductionDecls))
1553 ompBuilder->createOMPFreeShared(
1554 builder, var, moduleTranslation.
convertType(reductionDecl.getType()));
1567template <
typename OP>
1571 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1576 if (op.getNumReductionVars() == 0)
1582 allocaIP, reductionDecls,
1583 privateReductionVariables, reductionVariableMap,
1584 deferredStores, isByRef)))
1587 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1588 allocaIP.getBlock(), reductionDecls,
1589 privateReductionVariables, reductionVariableMap,
1590 isByRef, deferredStores);
1604 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1607 Value blockArg = (*mappedPrivateVars)[privateVar];
1610 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1611 "A block argument corresponding to a mapped var should have "
1614 if (privVarType == blockArgType)
1621 if (!isa<LLVM::LLVMPointerType>(privVarType))
1622 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1635 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1637 llvm::BasicBlock *privInitBlock,
1639 Region &initRegion = privDecl.getInitRegion();
1640 if (initRegion.
empty())
1641 return llvmPrivateVar;
1643 assert(nonPrivateVar);
1644 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1645 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1650 moduleTranslation, &phis)))
1651 return llvm::createStringError(
1652 "failed to inline `init` region of `omp.private`");
1654 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1671 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1674 builder, moduleTranslation, privDecl,
1677 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1686 return llvm::Error::success();
1688 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1691 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1694 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1696 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1697 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1700 return privVarOrErr.takeError();
1702 llvmPrivateVar = privVarOrErr.get();
1703 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1708 return llvm::Error::success();
1714template <
typename T>
1719 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1722 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1723 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1724 allocaTerminator->getIterator()),
1725 true, allocaTerminator->getStableDebugLoc(),
1726 "omp.region.after_alloca");
1728 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1730 allocaTerminator = allocaIP.getBlock()->getTerminator();
1731 builder.SetInsertPoint(allocaTerminator);
1733 assert(allocaTerminator->getNumSuccessors() == 1 &&
1734 "This is an unconditional branch created by splitBB");
1736 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1737 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1741 unsigned int allocaAS =
1742 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1745 .getProgramAddressSpace();
1747 for (
auto [privDecl, mlirPrivVar, blockArg] :
1750 llvm::Type *llvmAllocType =
1751 moduleTranslation.
convertType(privDecl.getType());
1752 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1753 llvm::Value *llvmPrivateVar =
nullptr;
1755 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1757 llvmPrivateVar = builder.CreateAlloca(
1758 llvmAllocType,
nullptr,
"omp.private.alloc");
1759 if (allocaAS != defaultAS)
1760 llvmPrivateVar = builder.CreateAddrSpaceCast(
1761 llvmPrivateVar, builder.getPtrTy(defaultAS));
1764 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1767 return afterAllocas;
1775 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1784 if (mlir::isa<omp::ParallelOp>(parent))
1798 bool needsFirstprivate =
1799 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1800 return privOp.getDataSharingType() ==
1801 omp::DataSharingClauseType::FirstPrivate;
1804 if (!needsFirstprivate)
1807 llvm::BasicBlock *copyBlock =
1808 splitBB(builder,
true,
"omp.private.copy");
1811 for (
auto [decl, moldVar, llvmVar] :
1812 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1813 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1817 Region ©Region = decl.getCopyRegion();
1819 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1822 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1826 moduleTranslation)))
1827 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1842 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1858 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1860 llvm::Value *moldVar = findAssociatedValue(
1861 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1866 llvmPrivateVars, privateDecls, insertBarrier,
1870template <
typename T>
1878 std::back_inserter(privateCleanupRegions),
1879 [](omp::PrivateClauseOp privatizer) {
1880 return &privatizer.getDeallocRegion();
1884 privateVarsInfo.
llvmVars, moduleTranslation,
1885 builder,
"omp.private.dealloc",
1887 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1888 "`omp.private` op in");
1892 for (
auto [privDecl, llvmPrivVar, blockArg] :
1896 ompBuilder->createOMPFreeShared(
1897 builder, llvmPrivVar,
1898 moduleTranslation.
convertType(privDecl.getType()));
1912 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1922 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1923 using StorableBodyGenCallbackTy =
1924 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1926 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1932 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1936 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1940 sectionsOp.getNumReductionVars());
1944 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1947 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1948 reductionDecls, privateReductionVariables, reductionVariableMap,
1955 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1959 Region ®ion = sectionOp.getRegion();
1960 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1961 InsertPointTy allocaIP, InsertPointTy codeGenIP,
1963 builder.restoreIP(codeGenIP);
1970 sectionsOp.getRegion().getNumArguments());
1971 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1972 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1973 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1975 moduleTranslation.
mapValue(sectionArg, llvmVal);
1982 sectionCBs.push_back(sectionCB);
1988 if (sectionCBs.empty())
1991 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1996 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1997 llvm::Value &vPtr, llvm::Value *&replacementValue)
1998 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1999 replacementValue = &vPtr;
2005 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2009 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2010 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2012 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2013 sectionsOp.getNowait());
2018 builder.restoreIP(*afterIP);
2022 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2023 privateReductionVariables, isByRef, sectionsOp.getNowait());
2030 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2037 assert(isByRef.size() == scopeOp.getNumReductionVars());
2046 scopeOp.getNumReductionVars());
2050 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2054 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2059 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2060 reductionDecls, privateReductionVariables, reductionVariableMap,
2065 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2067 builder.restoreIP(codeGenIP);
2073 return llvm::make_error<PreviouslyReportedError>();
2076 scopeOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2078 scopeOp.getPrivateNeedsBarrier())))
2079 return llvm::make_error<PreviouslyReportedError>();
2086 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2087 InsertPointTy oldIP = builder.saveIP();
2088 builder.restoreIP(codeGenIP);
2090 scopeOp.getLoc(), privateVarsInfo)))
2091 return llvm::make_error<PreviouslyReportedError>();
2092 builder.restoreIP(oldIP);
2093 return llvm::Error::success();
2096 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2097 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2098 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
2103 builder.restoreIP(*afterIP);
2107 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2108 privateReductionVariables, isByRef, scopeOp.getNowait(),
2116 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2117 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2122 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2124 builder.restoreIP(codegenIP);
2126 builder, moduleTranslation)
2129 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
2133 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
2136 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
2137 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
2139 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2140 llvmCPFuncs.push_back(
2144 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2146 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2152 builder.restoreIP(*afterIP);
2156static omp::DistributeOp
2160 omp::DistributeOp distOp;
2161 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
2167 if (walk.wasInterrupted() || !distOp)
2171 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2175 for (
auto ra : iface.getReductionBlockArgs())
2176 for (
auto &use : ra.getUses()) {
2177 auto *useOp = use.getOwner();
2179 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2180 debugUses.push_back(useOp);
2183 if (!distOp->isProperAncestor(useOp))
2190 for (
auto *use : debugUses)
2199 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2204 unsigned numReductionVars = op.getNumReductionVars();
2208 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2214 if (doTeamsReduction) {
2215 isByRef =
getIsByRef(op.getReductionByref());
2217 assert(isByRef.size() == op.getNumReductionVars());
2220 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2225 op, reductionArgs, builder, moduleTranslation, allocaIP,
2226 reductionDecls, privateReductionVariables, reductionVariableMap,
2231 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2234 moduleTranslation, allocaIP, deallocBlocks);
2235 builder.restoreIP(codegenIP);
2241 llvm::Value *numTeamsLower =
nullptr;
2242 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2243 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2245 llvm::Value *numTeamsUpper =
nullptr;
2246 if (!op.getNumTeamsUpperVars().empty())
2247 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2249 llvm::Value *threadLimit =
nullptr;
2250 if (!op.getThreadLimitVars().empty())
2251 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2253 llvm::Value *ifExpr =
nullptr;
2254 if (
Value ifVar = op.getIfExpr())
2257 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2258 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2260 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2265 builder.restoreIP(*afterIP);
2266 if (doTeamsReduction) {
2269 op, builder, moduleTranslation, allocaIP, reductionDecls,
2270 privateReductionVariables, isByRef,
2276static llvm::omp::RTLDependenceKindTy
2279 case mlir::omp::ClauseTaskDepend::taskdependin:
2280 return llvm::omp::RTLDependenceKindTy::DepIn;
2284 case mlir::omp::ClauseTaskDepend::taskdependout:
2285 case mlir::omp::ClauseTaskDepend::taskdependinout:
2286 return llvm::omp::RTLDependenceKindTy::DepInOut;
2287 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2288 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2289 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2290 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2292 llvm_unreachable(
"unhandled depend kind");
2296 std::optional<ArrayAttr> dependKinds,
OperandRange dependVars,
2299 if (dependVars.empty())
2301 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2303 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2305 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2306 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2307 dds.emplace_back(dd);
2319 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2321 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2322 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2326 llvmBuilder.restoreIP(ip);
2332 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2333 return llvm::Error::success();
2338 ompBuilder.pushFinalizationCB(
2348 llvm::OpenMPIRBuilder &ompBuilder,
2349 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2350 ompBuilder.popFinalizationCB();
2351 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2352 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2353 cancelBranch->setSuccessor(constructFini);
2359class TaskContextStructManager {
2361 TaskContextStructManager(llvm::IRBuilderBase &builder,
2362 LLVM::ModuleTranslation &moduleTranslation,
2363 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2364 : builder{builder}, moduleTranslation{moduleTranslation},
2365 privateDecls{privateDecls} {}
2371 void generateTaskContextStruct();
2377 void createGEPsToPrivateVars();
2383 SmallVector<llvm::Value *>
2384 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2387 void freeStructPtr();
2389 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2390 return llvmPrivateVarGEPs;
2393 llvm::Value *getStructPtr() {
return structPtr; }
2396 llvm::IRBuilderBase &builder;
2397 LLVM::ModuleTranslation &moduleTranslation;
2398 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2401 SmallVector<llvm::Type *> privateVarTypes;
2405 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2408 llvm::Value *structPtr =
nullptr;
2410 llvm::Type *structTy =
nullptr;
2421 llvm::SmallVector<llvm::Value *> lowerBounds;
2422 llvm::SmallVector<llvm::Value *> upperBounds;
2423 llvm::SmallVector<llvm::Value *> steps;
2424 llvm::SmallVector<llvm::Value *> trips;
2426 llvm::Value *totalTrips;
2428 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2429 llvm::IRBuilderBase &builder) {
2433 if (v->getType()->isIntegerTy(64))
2435 if (v->getType()->isIntegerTy())
2436 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2441 IteratorInfo(mlir::omp::IteratorOp itersOp,
2442 mlir::LLVM::ModuleTranslation &moduleTranslation,
2443 llvm::IRBuilderBase &builder) {
2444 dims = itersOp.getLoopLowerBounds().size();
2445 lowerBounds.resize(dims);
2446 upperBounds.resize(dims);
2450 for (
unsigned d = 0; d < dims; ++d) {
2451 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2452 moduleTranslation, builder);
2453 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2454 moduleTranslation, builder);
2456 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2457 assert(lb && ub && st &&
2458 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2459 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2460 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2461 "Expect non-zero step in IteratorOp");
2463 lowerBounds[d] = lb;
2464 upperBounds[d] = ub;
2468 llvm::Value *diff = builder.CreateSub(ub, lb);
2469 llvm::Value *
div = builder.CreateSDiv(diff, st);
2470 trips[d] = builder.CreateAdd(
2471 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2474 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2475 for (
unsigned d = 0; d < dims; ++d)
2476 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2479 unsigned getDims()
const {
return dims; }
2480 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
2481 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
2482 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
2483 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
2484 llvm::Value *getTotalTrips()
const {
return totalTrips; }
2489void TaskContextStructManager::generateTaskContextStruct() {
2490 if (privateDecls.empty())
2492 privateVarTypes.reserve(privateDecls.size());
2494 for (omp::PrivateClauseOp &privOp : privateDecls) {
2497 if (!privOp.readsFromMold())
2499 Type mlirType = privOp.getType();
2500 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2503 if (privateVarTypes.empty())
2506 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2509 llvm::DataLayout dataLayout =
2510 builder.GetInsertBlock()->getModule()->getDataLayout();
2511 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2512 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2515 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2517 "omp.task.context_ptr");
2520SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2521 llvm::Value *altStructPtr)
const {
2522 SmallVector<llvm::Value *> ret;
2525 ret.reserve(privateDecls.size());
2526 llvm::Value *zero = builder.getInt32(0);
2528 for (
auto privDecl : privateDecls) {
2529 if (!privDecl.readsFromMold()) {
2531 ret.push_back(
nullptr);
2534 llvm::Value *iVal = builder.getInt32(i);
2535 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2542void TaskContextStructManager::createGEPsToPrivateVars() {
2544 assert(privateVarTypes.empty());
2548 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2551void TaskContextStructManager::freeStructPtr() {
2555 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2557 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2558 builder.CreateFree(structPtr);
2562 llvm::OpenMPIRBuilder &ompBuilder,
2563 llvm::Value *affinityList, llvm::Value *
index,
2564 llvm::Value *addr, llvm::Value *len) {
2565 llvm::StructType *kmpTaskAffinityInfoTy =
2566 ompBuilder.getKmpTaskAffinityInfoTy();
2567 llvm::Value *entry = builder.CreateInBoundsGEP(
2568 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
2570 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2571 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2573 llvm::Value *flags = builder.getInt32(0);
2575 builder.CreateStore(addr,
2576 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2577 builder.CreateStore(len,
2578 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2579 builder.CreateStore(flags,
2580 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2584 llvm::IRBuilderBase &builder,
2586 llvm::Value *affinityList) {
2587 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2588 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2589 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2591 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2592 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2593 assert(addr && len &&
"expect affinity addr and len to be non-null");
2595 affinityList, builder.getInt64(i), addr, len);
2599static mlir::LogicalResult
2602 llvm::IRBuilderBase &builder,
2604 llvm::Value *tmp = linearIV;
2605 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2606 llvm::Value *trip = iterInfo.getTrips()[d];
2608 llvm::Value *idx = builder.CreateURem(tmp, trip);
2610 tmp = builder.CreateUDiv(tmp, trip);
2613 llvm::Value *physIV = builder.CreateAdd(
2614 iterInfo.getLowerBounds()[d],
2615 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2621 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2622 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2625 return mlir::failure();
2627 return mlir::success();
2633static mlir::LogicalResult
2636 IteratorInfo &iterInfo, llvm::StringRef loopName,
2641 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2643 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2644 llvm::Value *linearIV) -> llvm::Error {
2645 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2646 builder.restoreIP(bodyIP);
2649 builder, moduleTranslation))) {
2650 return llvm::make_error<llvm::StringError>(
2651 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2655 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2656 assert(yield && yield.getResults().size() == 1 &&
2657 "expect omp.yield in iterator region to have one result");
2659 genStoreEntry(linearIV, yield);
2665 return llvm::Error::success();
2668 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2670 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2674 builder.restoreIP(*afterIP);
2676 return mlir::success();
2679static mlir::LogicalResult
2682 llvm::OpenMPIRBuilder::AffinityData &ad) {
2684 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2687 return mlir::success();
2691 llvm::StructType *kmpTaskAffinityInfoTy =
2694 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2695 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2696 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2698 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2699 "omp.affinity_list");
2702 auto createAffinity =
2703 [&](llvm::Value *count,
2704 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2705 llvm::OpenMPIRBuilder::AffinityData ad{};
2706 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2708 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2712 if (!taskOp.getAffinityVars().empty()) {
2713 llvm::Value *count = llvm::ConstantInt::get(
2714 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2715 llvm::Value *list = allocateAffinityList(count);
2718 ads.emplace_back(createAffinity(count, list));
2721 if (!taskOp.getIterated().empty()) {
2722 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2723 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2724 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2725 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2726 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2728 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2729 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2730 auto entryOp = yield.getResults()[0]
2731 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2732 assert(entryOp &&
"expect yield produce an affinity entry");
2739 affList, linearIV, addr, len);
2741 return llvm::failure();
2742 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2746 llvm::Value *totalAffinityCount = builder.getInt32(0);
2747 for (
const auto &affinity : ads)
2748 totalAffinityCount = builder.CreateAdd(
2750 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2753 llvm::Value *affinityInfo = ads.front().Info;
2754 if (ads.size() > 1) {
2755 llvm::StructType *kmpTaskAffinityInfoTy =
2757 llvm::Value *affinityInfoElemSize = builder.getInt64(
2758 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2759 kmpTaskAffinityInfoTy));
2761 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2762 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2763 for (
const auto &affinity : ads) {
2764 llvm::Value *affinityCount = builder.CreateIntCast(
2765 affinity.Count, builder.getInt32Ty(),
false);
2766 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2767 affinityCount, builder.getInt64Ty(),
false);
2768 llvm::Value *affinityInfoSize =
2769 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2771 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2772 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2774 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2775 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2777 builder.CreateMemCpy(
2778 packedAffinityInfoIndex, llvm::Align(1),
2779 builder.CreatePointerBitCastOrAddrSpaceCast(
2780 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2781 ->getPointerAddressSpace())),
2782 llvm::Align(1), affinityInfoSize);
2784 packedAffinityInfoOffset =
2785 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2788 affinityInfo = packedAffinityInfo;
2791 ad.Count = totalAffinityCount;
2792 ad.Info = affinityInfo;
2794 return mlir::success();
2800static mlir::LogicalResult
2803 std::optional<ArrayAttr> dependIteratedKinds,
2804 llvm::IRBuilderBase &builder,
2806 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2807 if (dependIterated.empty()) {
2810 return mlir::success();
2814 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2815 unsigned numLocator = dependVars.size();
2818 llvm::Value *totalCount =
2819 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2822 for (
auto iter : dependIterated) {
2823 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2824 assert(itersOp &&
"depend_iterated value must be defined by omp.iterator");
2825 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2827 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2832 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2833 llvm::Value *depArray =
2834 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2835 totalCount,
nullptr,
".dep.arr.addr");
2838 if (numLocator > 0) {
2841 for (
auto [i, dd] : llvm::enumerate(dds)) {
2842 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2843 llvm::Value *entry =
2844 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2845 ompBuilder.emitTaskDependency(builder, entry, dd);
2850 llvm::Value *offset =
2851 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2852 for (
auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2853 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2854 dependIteratedKinds->getValue()[i]);
2855 llvm::omp::RTLDependenceKindTy rtlKind =
2858 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2860 itersOp, builder, moduleTranslation, iterInfo,
"dep_iterator",
2861 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2863 moduleTranslation.
lookupValue(yield.getResults()[0]);
2864 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2865 llvm::Value *entry =
2866 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2867 ompBuilder.emitTaskDependency(
2869 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2872 return mlir::failure();
2875 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2878 taskDeps.DepArray = depArray;
2879 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2880 return mlir::success();
2887 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2892 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2904 InsertPointTy allocaIP =
2909 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2910 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2911 builder.getContext(),
"omp.task.start",
2912 builder.GetInsertBlock()->getParent());
2913 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2914 builder.SetInsertPoint(branchToTaskStartBlock);
2917 llvm::BasicBlock *copyBlock =
2918 splitBB(builder,
true,
"omp.private.copy");
2919 llvm::BasicBlock *initBlock =
2920 splitBB(builder,
true,
"omp.private.init");
2936 moduleTranslation, allocaIP, deallocBlocks);
2939 builder.SetInsertPoint(initBlock->getTerminator());
2942 taskStructMgr.generateTaskContextStruct();
2949 taskStructMgr.createGEPsToPrivateVars();
2951 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2954 taskStructMgr.getLLVMPrivateVarGEPs())) {
2956 if (!privDecl.readsFromMold())
2958 assert(llvmPrivateVarAlloc &&
2959 "reads from mold so shouldn't have been skipped");
2962 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2963 blockArg, llvmPrivateVarAlloc, initBlock);
2964 if (!privateVarOrErr)
2965 return handleError(privateVarOrErr, *taskOp.getOperation());
2974 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2975 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2976 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2978 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2979 llvmPrivateVarAlloc);
2981 assert(llvmPrivateVarAlloc->getType() ==
2982 moduleTranslation.
convertType(blockArg.getType()));
2992 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2993 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2994 taskOp.getPrivateNeedsBarrier())))
2995 return llvm::failure();
2997 llvm::OpenMPIRBuilder::AffinityData ad;
2999 return llvm::failure();
3002 builder.SetInsertPoint(taskStartBlock);
3005 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3010 moduleTranslation, allocaIP, deallocBlocks);
3013 builder.restoreIP(codegenIP);
3015 llvm::BasicBlock *privInitBlock =
nullptr;
3017 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3020 auto [blockArg, privDecl, mlirPrivVar] = zip;
3022 if (privDecl.readsFromMold())
3025 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3026 llvm::Type *llvmAllocType =
3027 moduleTranslation.
convertType(privDecl.getType());
3028 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3029 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3030 llvmAllocType,
nullptr,
"omp.private.alloc");
3033 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3034 blockArg, llvmPrivateVar, privInitBlock);
3035 if (!privateVarOrError)
3036 return privateVarOrError.takeError();
3037 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3038 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3041 taskStructMgr.createGEPsToPrivateVars();
3042 for (
auto [i, llvmPrivVar] :
3043 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3045 assert(privateVarsInfo.
llvmVars[i] &&
3046 "This is added in the loop above");
3049 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3054 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3058 if (!privateDecl.readsFromMold())
3061 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3062 llvmPrivateVar = builder.CreateLoad(
3063 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3065 assert(llvmPrivateVar->getType() ==
3066 moduleTranslation.
convertType(blockArg.getType()));
3067 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3071 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
3072 if (failed(
handleError(continuationBlockOrError, *taskOp)))
3073 return llvm::make_error<PreviouslyReportedError>();
3075 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3078 taskOp.getLoc(), privateVarsInfo)))
3079 return llvm::make_error<PreviouslyReportedError>();
3082 taskStructMgr.freeStructPtr();
3084 return llvm::Error::success();
3093 llvm::omp::Directive::OMPD_taskgroup);
3095 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3096 if (failed(
buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3097 taskOp.getDependIterated(),
3098 taskOp.getDependIteratedKinds(), builder,
3099 moduleTranslation, dependencies)))
3102 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3103 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3105 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3107 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dependencies, ad,
3108 taskOp.getMergeable(),
3109 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
3110 moduleTranslation.
lookupValue(taskOp.getPriority()));
3118 builder.restoreIP(*afterIP);
3120 if (dependencies.DepArray)
3121 builder.CreateFree(dependencies.DepArray);
3130 llvm::IRBuilderBase &builder,
3138 loopWrapperOp.getRegion(),
"omp.taskloop.wrapper.region", builder,
3141 if (failed(
handleError(continuationBlockOrError, opInst)))
3144 builder.SetInsertPoint(continuationBlockOrError.get());
3152static llvm::Expected<llvm::Value *>
3155 llvm::IRBuilderBase &builder) {
3156 if (llvm::Value *mapped = moduleTranslation.
lookupValue(value))
3161 return llvm::make_error<llvm::StringError>(
3162 "value is a block argument and is not mapped",
3163 llvm::inconvertibleErrorCode());
3165 return llvm::make_error<llvm::StringError>(
3166 "unsupported op defining taskloop loop bound",
3167 llvm::inconvertibleErrorCode());
3177 if (!operandOrError)
3178 return operandOrError.takeError();
3179 moduleTranslation.
mapValue(operand, *operandOrError);
3180 mappingsToRemove.push_back(operand);
3184 return llvm::make_error<llvm::StringError>(
3185 "failed to convert op defining taskloop loop bound",
3186 llvm::inconvertibleErrorCode());
3189 assert(
result &&
"expected conversion of loop bound op to produce a value");
3193 mappingsToRemove.push_back(resultValue);
3195 for (
Value mappedValue : mappingsToRemove)
3204 llvm::Value *&lbVal, llvm::Value *&ubVal,
3205 llvm::Value *&stepVal) {
3213 return firstLbOrErr.takeError();
3215 llvm::Type *boundType = (*firstLbOrErr)->getType();
3216 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3217 if (loopOp.getCollapseNumLoops() > 1) {
3235 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3237 i == 0 ? std::move(firstLbOrErr)
3241 return lbOrErr.takeError();
3243 upperBounds[i], moduleTranslation, builder);
3245 return ubOrErr.takeError();
3249 return stepOrErr.takeError();
3251 llvm::Value *loopLb = *lbOrErr;
3252 llvm::Value *loopUb = *ubOrErr;
3253 llvm::Value *loopStep = *stepOrErr;
3259 llvm::Value *loopLbMinusOne = builder.CreateSub(
3260 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3261 llvm::Value *loopUbMinusOne = builder.CreateSub(
3262 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3263 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3264 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3265 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3266 llvm::Value *loopTripCount =
3267 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3268 loopTripCount = builder.CreateBinaryIntrinsic(
3269 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3273 llvm::Value *loopTripCountDivStep =
3274 builder.CreateSDiv(loopTripCount, loopStep);
3275 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3276 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3277 llvm::Value *loopTripCountRem =
3278 builder.CreateSRem(loopTripCount, loopStep);
3279 loopTripCountRem = builder.CreateBinaryIntrinsic(
3280 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3281 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3283 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3286 builder.CreateAdd(loopTripCountDivStep,
3287 builder.CreateZExtOrTrunc(
3288 needsRoundUp, loopTripCountDivStep->getType()));
3289 ubVal = builder.CreateMul(ubVal, loopTripCount);
3291 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3292 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3297 return ubOrErr.takeError();
3301 return stepOrErr.takeError();
3302 lbVal = *firstLbOrErr;
3304 stepVal = *stepOrErr;
3307 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3308 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3309 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3310 return llvm::Error::success();
3316 llvm::IRBuilderBase &builder,
3318 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3320 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3328 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3335 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3336 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3337 builder.getContext(),
"omp.taskloop.wrapper.start",
3338 builder.GetInsertBlock()->getParent());
3339 llvm::Instruction *branchToTaskloopStartBlock =
3340 builder.CreateBr(taskloopStartBlock);
3341 builder.SetInsertPoint(branchToTaskloopStartBlock);
3343 llvm::BasicBlock *copyBlock =
3344 splitBB(builder,
true,
"omp.private.copy");
3345 llvm::BasicBlock *initBlock =
3346 splitBB(builder,
true,
"omp.private.init");
3349 moduleTranslation, allocaIP, deallocBlocks);
3352 builder.SetInsertPoint(initBlock->getTerminator());
3355 taskStructMgr.generateTaskContextStruct();
3356 taskStructMgr.createGEPsToPrivateVars();
3358 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
3360 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3362 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3363 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3365 if (!privDecl.readsFromMold())
3367 assert(llvmPrivateVarAlloc &&
3368 "reads from mold so shouldn't have been skipped");
3371 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3372 blockArg, llvmPrivateVarAlloc, initBlock);
3373 if (!privateVarOrErr)
3374 return handleError(privateVarOrErr, *contextOp.getOperation());
3376 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3378 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3379 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3381 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3382 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3383 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3385 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3386 llvmPrivateVarAlloc);
3388 assert(llvmPrivateVarAlloc->getType() ==
3389 moduleTranslation.
convertType(blockArg.getType()));
3395 contextOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3396 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
3397 contextOp.getPrivateNeedsBarrier())))
3398 return llvm::failure();
3401 builder.SetInsertPoint(taskloopStartBlock);
3403 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3404 llvm::Value *lbVal =
nullptr;
3405 llvm::Value *ubVal =
nullptr;
3406 llvm::Value *stepVal =
nullptr;
3408 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3412 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3417 moduleTranslation, allocaIP, deallocBlocks);
3420 builder.restoreIP(codegenIP);
3422 llvm::BasicBlock *privInitBlock =
nullptr;
3424 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3427 auto [blockArg, privDecl, mlirPrivVar] = zip;
3429 if (privDecl.readsFromMold())
3432 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3433 llvm::Type *llvmAllocType =
3434 moduleTranslation.
convertType(privDecl.getType());
3435 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3436 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3437 llvmAllocType,
nullptr,
"omp.private.alloc");
3440 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3441 blockArg, llvmPrivateVar, privInitBlock);
3442 if (!privateVarOrError)
3443 return privateVarOrError.takeError();
3444 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
3445 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3448 taskStructMgr.createGEPsToPrivateVars();
3449 for (
auto [i, llvmPrivVar] :
3450 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3452 assert(privateVarsInfo.
llvmVars[i] &&
3453 "This is added in the loop above");
3456 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3461 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3465 if (!privateDecl.readsFromMold())
3468 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3469 llvmPrivateVar = builder.CreateLoad(
3470 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3472 assert(llvmPrivateVar->getType() ==
3473 moduleTranslation.
convertType(blockArg.getType()));
3474 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3480 contextOp.getRegion(),
"omp.taskloop.context.region", builder,
3483 if (failed(
handleError(continuationBlockOrError, opInst)))
3484 return llvm::make_error<PreviouslyReportedError>();
3486 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3494 contextOp.getLoc(), privateVarsInfo)))
3495 return llvm::make_error<PreviouslyReportedError>();
3498 taskStructMgr.freeStructPtr();
3500 return llvm::Error::success();
3506 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3507 llvm::Value *destPtr, llvm::Value *srcPtr)
3509 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3510 builder.restoreIP(codegenIP);
3513 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3515 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
3517 TaskContextStructManager &srcStructMgr = taskStructMgr;
3518 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3520 destStructMgr.generateTaskContextStruct();
3521 llvm::Value *dest = destStructMgr.getStructPtr();
3522 dest->setName(
"omp.taskloop.context.dest");
3523 builder.CreateStore(dest, destPtr);
3526 srcStructMgr.createGEPsToPrivateVars(src);
3528 destStructMgr.createGEPsToPrivateVars(dest);
3531 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3532 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3535 if (!privDecl.readsFromMold())
3537 assert(llvmPrivateVarAlloc &&
3538 "reads from mold so shouldn't have been skipped");
3541 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3542 llvmPrivateVarAlloc, builder.GetInsertBlock());
3543 if (!privateVarOrErr)
3544 return privateVarOrErr.takeError();
3553 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3554 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3555 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3557 llvmPrivateVarAlloc = builder.CreateLoad(
3558 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3560 assert(llvmPrivateVarAlloc->getType() ==
3561 moduleTranslation.
convertType(blockArg.getType()));
3569 moduleTranslation, srcGEPs, destGEPs,
3571 contextOp.getPrivateNeedsBarrier())))
3572 return llvm::make_error<PreviouslyReportedError>();
3574 return builder.saveIP();
3582 llvm::Value *ifCond =
nullptr;
3583 llvm::Value *grainsize =
nullptr;
3585 mlir::Value grainsizeVal = contextOp.getGrainsize();
3586 mlir::Value numTasksVal = contextOp.getNumTasks();
3587 if (
Value ifVar = contextOp.getIfExpr())
3590 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3592 }
else if (numTasksVal) {
3593 grainsize = moduleTranslation.
lookupValue(numTasksVal);
3597 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3598 if (taskStructMgr.getStructPtr())
3599 taskDupOrNull = taskDupCB;
3609 llvm::omp::Directive::OMPD_taskgroup);
3611 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3612 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3614 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3615 stepVal, contextOp.getUntied(), ifCond, grainsize,
3616 contextOp.getNogroup(), sched,
3617 moduleTranslation.
lookupValue(contextOp.getFinal()),
3618 contextOp.getMergeable(),
3619 moduleTranslation.
lookupValue(contextOp.getPriority()),
3620 loopOp.getCollapseNumLoops(), taskDupOrNull,
3621 taskStructMgr.getStructPtr());
3628 builder.restoreIP(*afterIP);
3636 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3640 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3642 builder.restoreIP(codegenIP);
3644 builder, moduleTranslation)
3649 InsertPointTy allocaIP =
3651 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3652 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3654 ompLoc, allocaIP, deallocBlocks, bodyCB);
3659 builder.restoreIP(*afterIP);
3678 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3682 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3684 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3688 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3691 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3692 llvm::Type *ivType = step->getType();
3693 llvm::Value *chunk =
nullptr;
3694 if (wsloopOp.getScheduleChunk()) {
3695 llvm::Value *chunkVar =
3696 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3697 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3700 omp::DistributeOp distributeOp =
nullptr;
3701 llvm::Value *distScheduleChunk =
nullptr;
3702 bool hasDistSchedule =
false;
3703 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3704 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3705 hasDistSchedule = distributeOp.getDistScheduleStatic();
3706 if (distributeOp.getDistScheduleChunkSize()) {
3707 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3708 distributeOp.getDistScheduleChunkSize());
3709 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3718 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3722 wsloopOp.getNumReductionVars());
3725 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3732 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3737 moduleTranslation, allocaIP, reductionDecls,
3738 privateReductionVariables, reductionVariableMap,
3739 deferredStores, isByRef)))
3748 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3750 wsloopOp.getPrivateNeedsBarrier())))
3753 assert(afterAllocas.get()->getSinglePredecessor());
3754 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3756 afterAllocas.get()->getSinglePredecessor(),
3757 reductionDecls, privateReductionVariables,
3758 reductionVariableMap, isByRef, deferredStores)))
3762 bool isOrdered = wsloopOp.getOrdered().has_value();
3763 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3764 bool isSimd = wsloopOp.getScheduleSimd();
3765 bool loopNeedsBarrier = !wsloopOp.getNowait();
3770 llvm::omp::WorksharingLoopType workshareLoopType =
3771 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3772 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3773 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3777 llvm::omp::Directive::OMPD_for);
3779 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3782 LinearClauseProcessor linearClauseProcessor;
3784 if (!wsloopOp.getLinearVars().empty()) {
3785 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3787 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3789 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3790 linearClauseProcessor.createLinearVar(
3791 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3793 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3794 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3798 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3806 if (!wsloopOp.getLinearVars().empty()) {
3807 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3808 loopInfo->getPreheader());
3809 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3811 builder.saveIP(), llvm::omp::OMPD_barrier);
3814 builder.restoreIP(*afterBarrierIP);
3815 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3816 loopInfo->getIndVar());
3817 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3820 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3823 bool noLoopMode =
false;
3824 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3826 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3830 if (loopOp == targetCapturedOp) {
3831 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3832 omp::TargetExecMode::no_loop)
3837 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3838 ompBuilder->applyWorkshareLoop(
3839 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3840 convertToScheduleKind(schedule), chunk, isSimd,
3841 scheduleMod == omp::ScheduleModifier::monotonic,
3842 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3843 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3849 if (!wsloopOp.getLinearVars().empty()) {
3850 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3851 assert(loopInfo->getLastIter() &&
3852 "`lastiter` in CanonicalLoopInfo is nullptr");
3853 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3854 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3855 loopInfo->getLastIter());
3858 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3859 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
3861 builder.restoreIP(oldIP);
3869 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3870 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3875 wsloopOp.getLoc(), privateVarsInfo);
3882 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3884 assert(isByRef.size() == opInst.getNumReductionVars());
3897 opInst.getNumReductionVars());
3901 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3904 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3906 return llvm::make_error<PreviouslyReportedError>();
3912 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3915 InsertPointTy(allocaIP.getBlock(),
3916 allocaIP.getBlock()->getTerminator()->getIterator());
3919 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3920 reductionDecls, privateReductionVariables, reductionVariableMap,
3921 deferredStores, isByRef)))
3922 return llvm::make_error<PreviouslyReportedError>();
3924 assert(afterAllocas.get()->getSinglePredecessor());
3925 builder.restoreIP(codeGenIP);
3931 return llvm::make_error<PreviouslyReportedError>();
3934 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3936 opInst.getPrivateNeedsBarrier())))
3937 return llvm::make_error<PreviouslyReportedError>();
3940 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3941 afterAllocas.get()->getSinglePredecessor(),
3942 reductionDecls, privateReductionVariables,
3943 reductionVariableMap, isByRef, deferredStores)))
3944 return llvm::make_error<PreviouslyReportedError>();
3949 moduleTranslation, allocaIP, deallocBlocks);
3953 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3955 return regionBlock.takeError();
3958 if (opInst.getNumReductionVars() > 0) {
3963 owningReductionGenRefDataPtrGens;
3965 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3967 owningReductionGenRefDataPtrGens,
3968 privateReductionVariables, reductionInfos, isByRef);
3971 builder.SetInsertPoint((*regionBlock)->getTerminator());
3974 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3975 builder.SetInsertPoint(tempTerminator);
3977 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3978 ompBuilder->createReductions(
3979 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3981 if (!contInsertPoint)
3982 return contInsertPoint.takeError();
3984 if (!contInsertPoint->getBlock())
3985 return llvm::make_error<PreviouslyReportedError>();
3987 tempTerminator->eraseFromParent();
3988 builder.restoreIP(*contInsertPoint);
3991 return llvm::Error::success();
3994 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3995 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
4004 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4005 InsertPointTy oldIP = builder.saveIP();
4006 builder.restoreIP(codeGenIP);
4011 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4012 [](omp::DeclareReductionOp reductionDecl) {
4013 return &reductionDecl.getCleanupRegion();
4016 reductionCleanupRegions, privateReductionVariables,
4017 moduleTranslation, builder,
"omp.reduction.cleanup")))
4018 return llvm::createStringError(
4019 "failed to inline `cleanup` region of `omp.declare_reduction`");
4022 opInst.getLoc(), privateVarsInfo)))
4023 return llvm::make_error<PreviouslyReportedError>();
4027 if (isCancellable) {
4028 auto IPOrErr = ompBuilder->createBarrier(
4029 llvm::OpenMPIRBuilder::LocationDescription(builder),
4030 llvm::omp::Directive::OMPD_unknown,
4034 return IPOrErr.takeError();
4037 builder.restoreIP(oldIP);
4038 return llvm::Error::success();
4041 llvm::Value *ifCond =
nullptr;
4042 if (
auto ifVar = opInst.getIfExpr())
4044 llvm::Value *numThreads =
nullptr;
4045 if (!opInst.getNumThreadsVars().empty())
4046 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
4047 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4048 if (
auto bind = opInst.getProcBindKind())
4052 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4054 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4056 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4057 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4058 privCB, finiCB, ifCond, numThreads, pbKind,
4064 builder.restoreIP(*afterIP);
4069static llvm::omp::OrderKind
4072 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4074 case omp::ClauseOrderKind::Concurrent:
4075 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4077 llvm_unreachable(
"Unknown ClauseOrderKind kind");
4085 auto simdOp = cast<omp::SimdOp>(opInst);
4093 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4096 simdOp.getNumReductionVars());
4101 assert(isByRef.size() == simdOp.getNumReductionVars());
4103 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4107 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
4112 LinearClauseProcessor linearClauseProcessor;
4114 if (!simdOp.getLinearVars().empty()) {
4115 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4117 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4118 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4119 bool isImplicit =
false;
4120 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4124 if (linearVar == mlirPrivVar) {
4126 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4127 llvmPrivateVar, idx);
4133 linearClauseProcessor.createLinearVar(
4134 builder, moduleTranslation,
4137 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
4138 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4142 moduleTranslation, allocaIP, reductionDecls,
4143 privateReductionVariables, reductionVariableMap,
4144 deferredStores, isByRef)))
4155 assert(afterAllocas.get()->getSinglePredecessor());
4156 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4158 afterAllocas.get()->getSinglePredecessor(),
4159 reductionDecls, privateReductionVariables,
4160 reductionVariableMap, isByRef, deferredStores)))
4163 llvm::ConstantInt *simdlen =
nullptr;
4164 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4165 simdlen = builder.getInt64(simdlenVar.value());
4167 llvm::ConstantInt *safelen =
nullptr;
4168 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4169 safelen = builder.getInt64(safelenVar.value());
4171 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4174 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4175 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4177 for (
size_t i = 0; i < operands.size(); ++i) {
4178 llvm::Value *alignment =
nullptr;
4179 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
4180 llvm::Type *ty = llvmVal->getType();
4182 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4183 alignment = builder.getInt64(intAttr.getInt());
4184 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
4185 assert(alignment &&
"Invalid alignment value");
4189 if (!intAttr.getValue().isPowerOf2())
4192 auto curInsert = builder.saveIP();
4193 builder.SetInsertPoint(sourceBlock);
4194 llvmVal = builder.CreateLoad(ty, llvmVal);
4195 builder.restoreIP(curInsert);
4196 alignedVars[llvmVal] = alignment;
4200 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
4207 if (simdOp.getLinearVars().size()) {
4208 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4209 loopInfo->getPreheader());
4211 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4212 loopInfo->getIndVar());
4214 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4216 ompBuilder->applySimd(loopInfo, alignedVars,
4218 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
4220 order, simdlen, safelen);
4222 linearClauseProcessor.emitStoresForLinearVar(builder);
4225 bool hasOrderedRegions =
false;
4226 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4227 hasOrderedRegions =
true;
4231 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
4232 linearClauseProcessor.rewriteInPlace(builder,
"omp.loop_nest.region",
4234 if (hasOrderedRegions) {
4236 linearClauseProcessor.rewriteInPlace(builder,
"omp.ordered.region",
4239 linearClauseProcessor.rewriteInPlace(builder,
"omp_region.finalize",
4248 for (
auto [i, tuple] : llvm::enumerate(
4249 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4250 privateReductionVariables))) {
4251 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4253 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
4254 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
4255 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
4259 llvm::Value *redValue = originalVariable;
4262 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
4263 llvm::Value *privateRedValue = builder.CreateLoad(
4264 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
4265 llvm::Value *reduced;
4267 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4270 builder.restoreIP(res.get());
4274 builder.CreateStore(reduced, originalVariable);
4279 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4280 [](omp::DeclareReductionOp reductionDecl) {
4281 return &reductionDecl.getCleanupRegion();
4284 moduleTranslation, builder,
4285 "omp.reduction.cleanup")))
4297 auto loopOp = cast<omp::LoopNestOp>(opInst);
4303 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4308 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4309 llvm::Value *iv) -> llvm::Error {
4312 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4317 bodyInsertPoints.push_back(ip);
4319 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4320 return llvm::Error::success();
4323 builder.restoreIP(ip);
4325 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
4327 return regionBlock.takeError();
4329 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4330 return llvm::Error::success();
4338 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4339 llvm::Value *lowerBound =
4340 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
4341 llvm::Value *upperBound =
4342 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
4343 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
4348 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4349 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4351 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4353 computeIP = loopInfos.front()->getPreheaderIP();
4357 ompBuilder->createCanonicalLoop(
4358 loc, bodyGen, lowerBound, upperBound, step,
4359 true, loopOp.getLoopInclusive(), computeIP);
4364 loopInfos.push_back(*loopResult);
4367 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4368 loopInfos.front()->getAfterIP();
4371 if (
const auto &tiles = loopOp.getTileSizes()) {
4372 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4375 for (
auto tile : tiles.value()) {
4376 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4377 tileSizes.push_back(tileVal);
4380 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4381 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4385 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4386 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4387 afterIP = {afterAfterBB, afterAfterBB->begin()};
4391 for (
const auto &newLoop : newLoops)
4392 loopInfos.push_back(newLoop);
4396 const auto &numCollapse = loopOp.getCollapseNumLoops();
4398 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4400 auto newTopLoopInfo =
4401 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4403 assert(newTopLoopInfo &&
"New top loop information is missing");
4404 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4405 [&](OpenMPLoopInfoStackFrame &frame) {
4406 frame.loopInfo = newTopLoopInfo;
4414 builder.restoreIP(afterIP);
4424 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4425 Value loopIV = op.getInductionVar();
4426 Value loopTC = op.getTripCount();
4428 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4431 ompBuilder->createCanonicalLoop(
4433 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4436 moduleTranslation.
mapValue(loopIV, llvmIV);
4438 builder.restoreIP(ip);
4443 return bodyGenStatus.takeError();
4445 llvmTC,
"omp.loop");
4447 return op.emitError(llvm::toString(llvmOrError.takeError()));
4449 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4450 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4451 builder.restoreIP(afterIP);
4454 if (
Value cli = op.getCli())
4467 Value applyee = op.getApplyee();
4468 assert(applyee &&
"Loop to apply unrolling on required");
4470 llvm::CanonicalLoopInfo *consBuilderCLI =
4472 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4473 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4481static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4484 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4489 for (
Value size : op.getSizes()) {
4490 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4491 assert(translatedSize &&
4492 "sizes clause arguments must already be translated");
4493 translatedSizes.push_back(translatedSize);
4496 for (
Value applyee : op.getApplyees()) {
4497 llvm::CanonicalLoopInfo *consBuilderCLI =
4499 assert(applyee &&
"Canonical loop must already been translated");
4500 translatedLoops.push_back(consBuilderCLI);
4503 auto generatedLoops =
4504 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4505 if (!op.getGeneratees().empty()) {
4506 for (
auto [mlirLoop,
genLoop] :
4507 zip_equal(op.getGeneratees(), generatedLoops))
4512 for (
Value applyee : op.getApplyees())
4520static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4523 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4527 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4528 Value applyee = op.getApplyees()[i];
4529 llvm::CanonicalLoopInfo *consBuilderCLI =
4531 assert(applyee &&
"Canonical loop must already been translated");
4532 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4533 beforeFuse.push_back(consBuilderCLI);
4534 else if (op.getCount().has_value() &&
4535 i >= op.getFirst().value() + op.getCount().value() - 1)
4536 afterFuse.push_back(consBuilderCLI);
4538 toFuse.push_back(consBuilderCLI);
4541 (op.getGeneratees().empty() ||
4542 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4543 "Wrong number of generatees");
4546 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4547 if (!op.getGeneratees().empty()) {
4549 for (; i < beforeFuse.size(); i++)
4550 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4551 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4552 for (; i < afterFuse.size(); i++)
4553 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4557 for (
Value applyee : op.getApplyees())
4564static llvm::AtomicOrdering
4567 return llvm::AtomicOrdering::Monotonic;
4570 case omp::ClauseMemoryOrderKind::Seq_cst:
4571 return llvm::AtomicOrdering::SequentiallyConsistent;
4572 case omp::ClauseMemoryOrderKind::Acq_rel:
4573 return llvm::AtomicOrdering::AcquireRelease;
4574 case omp::ClauseMemoryOrderKind::Acquire:
4575 return llvm::AtomicOrdering::Acquire;
4576 case omp::ClauseMemoryOrderKind::Release:
4577 return llvm::AtomicOrdering::Release;
4578 case omp::ClauseMemoryOrderKind::Relaxed:
4579 return llvm::AtomicOrdering::Monotonic;
4581 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
4588 auto readOp = cast<omp::AtomicReadOp>(opInst);
4593 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4596 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4599 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4600 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
4602 llvm::Type *elementType =
4603 moduleTranslation.
convertType(readOp.getElementType());
4605 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4606 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
4607 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4615 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4620 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4623 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4625 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4626 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
4627 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4628 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4631 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4639 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4640 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4641 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4642 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4643 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4644 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4645 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4646 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4647 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
4648 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4652 bool &isIgnoreDenormalMode,
4653 bool &isFineGrainedMemory,
4654 bool &isRemoteMemory) {
4655 isIgnoreDenormalMode =
false;
4656 isFineGrainedMemory =
false;
4657 isRemoteMemory =
false;
4658 if (atomicUpdateOp &&
4659 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4660 mlir::omp::AtomicControlAttr atomicControlAttr =
4661 atomicUpdateOp.getAtomicControlAttr();
4662 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4663 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4664 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4671 llvm::IRBuilderBase &builder,
4678 auto &innerOpList = opInst.getRegion().front().getOperations();
4679 bool isXBinopExpr{
false};
4680 llvm::AtomicRMWInst::BinOp binop;
4682 llvm::Value *llvmExpr =
nullptr;
4683 llvm::Value *llvmX =
nullptr;
4684 llvm::Type *llvmXElementType =
nullptr;
4685 if (innerOpList.size() == 2) {
4691 opInst.getRegion().getArgument(0))) {
4692 return opInst.emitError(
"no atomic update operation with region argument"
4693 " as operand found inside atomic.update region");
4696 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4698 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4702 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4704 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4706 opInst.getRegion().getArgument(0).getType());
4707 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4711 llvm::AtomicOrdering atomicOrdering =
4716 [&opInst, &moduleTranslation](
4717 llvm::Value *atomicx,
4720 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4721 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4722 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4723 return llvm::make_error<PreviouslyReportedError>();
4725 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4726 assert(yieldop && yieldop.getResults().size() == 1 &&
4727 "terminator must be omp.yield op and it must have exactly one "
4729 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4732 bool isIgnoreDenormalMode;
4733 bool isFineGrainedMemory;
4734 bool isRemoteMemory;
4739 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4740 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4741 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4742 atomicOrdering, binop, updateFn,
4743 isXBinopExpr, isIgnoreDenormalMode,
4744 isFineGrainedMemory, isRemoteMemory);
4749 builder.restoreIP(*afterIP);
4755 llvm::IRBuilderBase &builder,
4762 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4763 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4765 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4766 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4768 assert((atomicUpdateOp || atomicWriteOp) &&
4769 "internal op must be an atomic.update or atomic.write op");
4771 if (atomicWriteOp) {
4772 isPostfixUpdate =
true;
4773 mlirExpr = atomicWriteOp.getExpr();
4775 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4776 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4777 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4780 if (innerOpList.size() == 2) {
4783 atomicUpdateOp.getRegion().getArgument(0))) {
4784 return atomicUpdateOp.emitError(
4785 "no atomic update operation with region argument"
4786 " as operand found inside atomic.update region");
4790 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4793 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4797 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4798 llvm::Value *llvmX =
4799 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4800 llvm::Value *llvmV =
4801 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4802 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4803 atomicCaptureOp.getAtomicReadOp().getElementType());
4804 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4807 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4811 llvm::AtomicOrdering atomicOrdering =
4815 [&](llvm::Value *atomicx,
4818 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4819 Block &bb = *atomicUpdateOp.getRegion().
begin();
4820 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4822 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4823 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4824 return llvm::make_error<PreviouslyReportedError>();
4826 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4827 assert(yieldop && yieldop.getResults().size() == 1 &&
4828 "terminator must be omp.yield op and it must have exactly one "
4830 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4833 bool isIgnoreDenormalMode;
4834 bool isFineGrainedMemory;
4835 bool isRemoteMemory;
4837 isFineGrainedMemory, isRemoteMemory);
4840 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4842 ompBuilder->createAtomicCapture(
4843 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4844 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4845 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4847 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4850 builder.restoreIP(*afterIP);
4855 omp::ClauseCancellationConstructType directive) {
4856 switch (directive) {
4857 case omp::ClauseCancellationConstructType::Loop:
4858 return llvm::omp::Directive::OMPD_for;
4859 case omp::ClauseCancellationConstructType::Parallel:
4860 return llvm::omp::Directive::OMPD_parallel;
4861 case omp::ClauseCancellationConstructType::Sections:
4862 return llvm::omp::Directive::OMPD_sections;
4863 case omp::ClauseCancellationConstructType::Taskgroup:
4864 return llvm::omp::Directive::OMPD_taskgroup;
4866 llvm_unreachable(
"Unhandled cancellation construct type");
4875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4878 llvm::Value *ifCond =
nullptr;
4879 if (
Value ifVar = op.getIfExpr())
4882 llvm::omp::Directive cancelledDirective =
4885 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4886 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4888 if (failed(
handleError(afterIP, *op.getOperation())))
4891 builder.restoreIP(afterIP.get());
4898 llvm::IRBuilderBase &builder,
4903 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4906 llvm::omp::Directive cancelledDirective =
4909 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4910 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4912 if (failed(
handleError(afterIP, *op.getOperation())))
4915 builder.restoreIP(afterIP.get());
4925 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4927 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4932 Value symAddr = threadprivateOp.getSymAddr();
4935 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4938 if (!isa<LLVM::AddressOfOp>(symOp))
4939 return opInst.
emitError(
"Addressing symbol not found");
4940 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4942 LLVM::GlobalOp global =
4943 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4944 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4945 llvm::Type *type = globalValue->getValueType();
4946 llvm::TypeSize typeSize =
4947 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4949 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4950 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4951 ompLoc, globalValue, size, global.getSymName() +
".cache");
4957static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4959 switch (deviceClause) {
4960 case mlir::omp::DeclareTargetDeviceType::host:
4961 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4963 case mlir::omp::DeclareTargetDeviceType::nohost:
4964 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4966 case mlir::omp::DeclareTargetDeviceType::any:
4967 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4970 llvm_unreachable(
"unhandled device clause");
4973static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4975 mlir::omp::DeclareTargetCaptureClause captureClause) {
4976 switch (captureClause) {
4977 case mlir::omp::DeclareTargetCaptureClause::to:
4978 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4979 case mlir::omp::DeclareTargetCaptureClause::link:
4980 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4981 case mlir::omp::DeclareTargetCaptureClause::enter:
4982 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4983 case mlir::omp::DeclareTargetCaptureClause::none:
4984 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4986 llvm_unreachable(
"unhandled capture clause");
4991 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4993 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4994 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4995 return modOp.lookupSymbol(addressOfOp.getGlobalName());
5002 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5003 value = addrCast.getOperand();
5020static llvm::SmallString<64>
5022 llvm::OpenMPIRBuilder &ompBuilder,
5023 llvm::vfs::FileSystem &vfs) {
5025 llvm::raw_svector_ostream os(suffix);
5028 auto fileInfoCallBack = [&loc]() {
5029 return std::pair<std::string, uint64_t>(
5030 llvm::StringRef(loc.getFilename()), loc.getLine());
5035 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5037 os <<
"_decl_tgt_ref_ptr";
5043 if (
auto declareTargetGlobal =
5044 dyn_cast_if_present<omp::DeclareTargetInterface>(
5046 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5047 omp::DeclareTargetCaptureClause::link)
5053 if (
auto declareTargetGlobal =
5054 dyn_cast_if_present<omp::DeclareTargetInterface>(
5056 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5057 omp::DeclareTargetCaptureClause::to ||
5058 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5059 omp::DeclareTargetCaptureClause::enter)
5073 if (
auto declareTargetGlobal =
5074 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
5077 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
5078 omp::DeclareTargetCaptureClause::link) ||
5079 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5080 omp::DeclareTargetCaptureClause::to &&
5081 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5085 if (gOp.getSymName().contains(suffix))
5090 (gOp.getSymName().str() + suffix.str()).str());
5099struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
5100 SmallVector<Operation *, 4> Mappers;
5103 void append(MapInfosTy &curInfo) {
5104 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5105 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
5114struct MapInfoData : MapInfosTy {
5115 llvm::SmallVector<bool, 4> IsDeclareTarget;
5116 llvm::SmallVector<bool, 4> IsAMember;
5118 llvm::SmallVector<bool, 4> IsAMapping;
5119 llvm::SmallVector<mlir::Operation *, 4> MapClause;
5120 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
5123 llvm::SmallVector<llvm::Type *, 4> BaseType;
5126 void append(MapInfoData &CurInfo) {
5127 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5128 CurInfo.IsDeclareTarget.end());
5129 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5130 OriginalValue.append(CurInfo.OriginalValue.begin(),
5131 CurInfo.OriginalValue.end());
5132 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5133 MapInfosTy::append(CurInfo);
5137enum class TargetDirectiveEnumTy : uint32_t {
5141 TargetEnterData = 3,
5146static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5147 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5148 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
5149 .Case([](omp::TargetEnterDataOp) {
5150 return TargetDirectiveEnumTy::TargetEnterData;
5152 .Case([&](omp::TargetExitDataOp) {
5153 return TargetDirectiveEnumTy::TargetExitData;
5155 .Case([&](omp::TargetUpdateOp) {
5156 return TargetDirectiveEnumTy::TargetUpdate;
5158 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
5159 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
5166 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5167 arrTy.getElementType()))
5184 llvm::Value *basePointer,
5185 llvm::Type *baseType,
5186 llvm::IRBuilderBase &builder,
5188 if (
auto memberClause =
5189 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5194 if (!memberClause.getBounds().empty()) {
5195 llvm::Value *elementCount = builder.getInt64(1);
5196 for (
auto bounds : memberClause.getBounds()) {
5197 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5198 bounds.getDefiningOp())) {
5203 elementCount = builder.CreateMul(
5207 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
5208 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
5209 builder.getInt64(1)));
5216 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5224 return builder.CreateMul(elementCount,
5225 builder.getInt64(underlyingTypeSzInBits / 8));
5236static llvm::omp::OpenMPOffloadMappingFlags
5238 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
5239 return (mlirFlags & flag) == flag;
5241 const bool hasExplicitMap =
5242 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
5243 omp::ClauseMapFlags::none;
5245 llvm::omp::OpenMPOffloadMappingFlags mapType =
5246 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5249 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5252 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5255 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5258 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5261 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5264 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5267 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5270 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5273 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5276 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5279 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5282 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5285 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5286 if (!hasExplicitMap)
5287 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5297 ArrayRef<Value> useDevAddrOperands = {},
5298 ArrayRef<Value> hasDevAddrOperands = {}) {
5299 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
5307 for (Value mapValue : mapVars) {
5308 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5309 for (
auto member : map.getMembers())
5310 if (member == mapOp)
5317 for (Value mapValue : mapVars) {
5318 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5320 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5321 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
5322 mapData.Pointers.push_back(mapData.OriginalValue.back());
5324 if (llvm::Value *refPtr =
5326 mapData.IsDeclareTarget.push_back(
true);
5327 mapData.BasePointers.push_back(refPtr);
5329 mapData.IsDeclareTarget.push_back(
true);
5330 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5332 mapData.IsDeclareTarget.push_back(
false);
5333 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5336 mapData.BaseType.push_back(
5337 moduleTranslation.
convertType(mapOp.getVarType()));
5338 mapData.Sizes.push_back(
5339 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
5340 mapData.BaseType.back(), builder, moduleTranslation));
5341 mapData.MapClause.push_back(mapOp.getOperation());
5345 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5346 if (mapOp.getMapperId())
5347 mapData.Mappers.push_back(
5349 mapOp, mapOp.getMapperIdAttr()));
5351 mapData.Mappers.push_back(
nullptr);
5352 mapData.IsAMapping.push_back(
true);
5353 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5356 auto findMapInfo = [&mapData](llvm::Value *val,
5357 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5360 for (llvm::Value *basePtr : mapData.OriginalValue) {
5361 if (basePtr == val && mapData.IsAMapping[index]) {
5363 mapData.Types[index] |=
5364 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5365 mapData.DevicePointers[index] = devInfoTy;
5373 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
5374 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5375 for (Value mapValue : useDevOperands) {
5376 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5378 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5379 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5382 if (!findMapInfo(origValue, devInfoTy)) {
5383 mapData.OriginalValue.push_back(origValue);
5384 mapData.Pointers.push_back(mapData.OriginalValue.back());
5385 mapData.IsDeclareTarget.push_back(
false);
5386 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5387 mapData.BaseType.push_back(
5388 moduleTranslation.
convertType(mapOp.getVarType()));
5389 mapData.Sizes.push_back(builder.getInt64(0));
5390 mapData.MapClause.push_back(mapOp.getOperation());
5391 mapData.Types.push_back(
5392 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5395 mapData.DevicePointers.push_back(devInfoTy);
5396 mapData.Mappers.push_back(
nullptr);
5397 mapData.IsAMapping.push_back(
false);
5398 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5403 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5404 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5406 for (Value mapValue : hasDevAddrOperands) {
5407 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5409 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5410 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5412 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5414 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5415 omp::ClauseMapFlags::none;
5417 mapData.OriginalValue.push_back(origValue);
5418 mapData.BasePointers.push_back(origValue);
5419 mapData.Pointers.push_back(origValue);
5420 mapData.IsDeclareTarget.push_back(
false);
5421 mapData.BaseType.push_back(
5422 moduleTranslation.
convertType(mapOp.getVarType()));
5423 mapData.Sizes.push_back(
5424 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
5425 mapData.MapClause.push_back(mapOp.getOperation());
5426 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5430 mapData.Types.push_back(mapType);
5434 if (mapOp.getMapperId()) {
5435 mapData.Mappers.push_back(
5437 mapOp, mapOp.getMapperIdAttr()));
5439 mapData.Mappers.push_back(
nullptr);
5444 mapData.Types.push_back(
5445 isDevicePtr ? mapType
5446 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5447 mapData.Mappers.push_back(
nullptr);
5451 mapData.DevicePointers.push_back(
5452 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5453 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5454 mapData.IsAMapping.push_back(
false);
5455 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5460 auto *res = llvm::find(mapData.MapClause, memberOp);
5461 assert(res != mapData.MapClause.end() &&
5462 "MapInfoOp for member not found in MapData, cannot return index");
5463 return std::distance(mapData.MapClause.begin(), res);
5467 omp::MapInfoOp mapInfo) {
5468 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5478 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5479 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5481 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5482 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5483 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5485 if (aIndex == bIndex)
5488 if (aIndex < bIndex)
5491 if (aIndex > bIndex)
5498 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5500 occludedChildren.push_back(
b);
5502 occludedChildren.push_back(a);
5503 return memberAParent;
5509 for (
auto v : occludedChildren)
5516 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5518 if (indexAttr.size() == 1)
5519 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5523 return llvm::cast<omp::MapInfoOp>(
5548static std::vector<llvm::Value *>
5550 llvm::IRBuilderBase &builder,
bool isArrayTy,
5552 std::vector<llvm::Value *> idx;
5563 idx.push_back(builder.getInt64(0));
5564 for (
int i = bounds.size() - 1; i >= 0; --i) {
5565 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5566 bounds[i].getDefiningOp())) {
5567 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5585 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5586 for (
int i = bounds.size() - 1; i >= 0; --i) {
5587 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5588 bounds[i].getDefiningOp())) {
5589 if (i == ((
int)bounds.size() - 1))
5591 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5593 idx.back() = builder.CreateAdd(
5594 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
5595 boundOp.getExtent())),
5596 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5605 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
5606 return cast<IntegerAttr>(value).getInt();
5614 omp::MapInfoOp parentOp) {
5616 if (parentOp.getMembers().empty())
5620 if (parentOp.getMembers().size() == 1) {
5621 overlapMapDataIdxs.push_back(0);
5627 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5628 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5629 memberByIndex.push_back(
5630 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
5635 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5636 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
5642 for (
auto v : memberByIndex) {
5646 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
5649 llvm::SmallVector<int64_t> xArr(x.second.size());
5650 getAsIntegers(x.second, xArr);
5651 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5652 xArr.size() >= vArr.size();
5658 for (
auto v : memberByIndex)
5659 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5660 overlapMapDataIdxs.push_back(v.first);
5672 if (mapOp.getVarPtrPtr())
5701 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5702 MapInfoData &mapData, uint64_t mapDataIndex,
5703 TargetDirectiveEnumTy targetDirective) {
5704 assert(!ompBuilder.Config.isTargetDevice() &&
5705 "function only supported for host device codegen");
5708 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5710 auto *parentMapper = mapData.Mappers[mapDataIndex];
5716 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
5717 (targetDirective == TargetDirectiveEnumTy::Target &&
5718 !mapData.IsDeclareTarget[mapDataIndex])
5719 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
5720 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5723 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
5727 mapFlags parentFlags = mapData.Types[mapDataIndex];
5728 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
5729 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
5730 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
5731 baseFlag |= (parentFlags & preserve);
5734 combinedInfo.Types.emplace_back(baseFlag);
5735 combinedInfo.DevicePointers.emplace_back(
5736 mapData.DevicePointers[mapDataIndex]);
5740 combinedInfo.Mappers.emplace_back(
5741 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
5743 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5744 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
5753 llvm::Value *lowAddr, *highAddr;
5754 if (!parentClause.getPartialMap()) {
5755 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5756 builder.getPtrTy());
5757 highAddr = builder.CreatePointerCast(
5758 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5759 mapData.Pointers[mapDataIndex], 1),
5760 builder.getPtrTy());
5761 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5763 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5766 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5767 builder.getPtrTy());
5770 highAddr = builder.CreatePointerCast(
5771 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5772 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5773 builder.getPtrTy());
5774 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
5777 llvm::Value *size = builder.CreateIntCast(
5778 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5779 builder.getInt64Ty(),
5781 combinedInfo.Sizes.push_back(size);
5783 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5784 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5792 if (!parentClause.getPartialMap()) {
5797 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5798 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5799 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5800 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5801 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5803 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
5804 combinedInfo.Types.emplace_back(mapFlag);
5805 combinedInfo.DevicePointers.emplace_back(
5806 mapData.DevicePointers[mapDataIndex]);
5808 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5809 combinedInfo.BasePointers.emplace_back(
5810 mapData.BasePointers[mapDataIndex]);
5811 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5812 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5813 combinedInfo.Mappers.emplace_back(
nullptr);
5824 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5825 builder.getPtrTy());
5826 highAddr = builder.CreatePointerCast(
5827 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5828 mapData.Pointers[mapDataIndex], 1),
5829 builder.getPtrTy());
5836 for (
auto v : overlapIdxs) {
5839 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5840 combinedInfo.Types.emplace_back(mapFlag);
5841 combinedInfo.DevicePointers.emplace_back(
5842 mapData.DevicePointers[mapDataOverlapIdx]);
5844 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5845 combinedInfo.BasePointers.emplace_back(
5846 mapData.BasePointers[mapDataIndex]);
5847 combinedInfo.Mappers.emplace_back(
nullptr);
5848 combinedInfo.Pointers.emplace_back(lowAddr);
5849 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5850 builder.CreatePtrDiff(builder.getInt8Ty(),
5851 mapData.OriginalValue[mapDataOverlapIdx],
5853 builder.getInt64Ty(),
true));
5854 lowAddr = builder.CreateConstGEP1_32(
5856 mapData.MapClause[mapDataOverlapIdx]))
5857 ? builder.getPtrTy()
5858 : mapData.BaseType[mapDataOverlapIdx],
5859 mapData.BasePointers[mapDataOverlapIdx], 1);
5862 combinedInfo.Types.emplace_back(mapFlag);
5863 combinedInfo.DevicePointers.emplace_back(
5864 mapData.DevicePointers[mapDataIndex]);
5866 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5867 combinedInfo.BasePointers.emplace_back(
5868 mapData.BasePointers[mapDataIndex]);
5869 combinedInfo.Mappers.emplace_back(
nullptr);
5870 combinedInfo.Pointers.emplace_back(lowAddr);
5871 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5872 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5873 builder.getInt64Ty(),
true));
5876 return memberOfFlag;
5882 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5883 MapInfoData &mapData, uint64_t mapDataIndex,
5884 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5885 TargetDirectiveEnumTy targetDirective) {
5886 assert(!ompBuilder.Config.isTargetDevice() &&
5887 "function only supported for host device codegen");
5890 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5892 for (
auto mappedMembers : parentClause.getMembers()) {
5894 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5897 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
5908 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5909 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5910 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5911 combinedInfo.Types.emplace_back(mapFlag);
5912 combinedInfo.DevicePointers.emplace_back(
5913 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5914 combinedInfo.Mappers.emplace_back(
nullptr);
5915 combinedInfo.Names.emplace_back(
5917 combinedInfo.BasePointers.emplace_back(
5918 mapData.BasePointers[mapDataIndex]);
5919 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5920 combinedInfo.Sizes.emplace_back(builder.getInt64(
5921 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
5927 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5928 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5929 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5931 ? parentClause.getVarPtr()
5932 : parentClause.getVarPtrPtr());
5935 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5936 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5937 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5940 combinedInfo.Types.emplace_back(mapFlag);
5941 combinedInfo.DevicePointers.emplace_back(
5942 mapData.DevicePointers[memberDataIdx]);
5943 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5944 combinedInfo.Names.emplace_back(
5946 uint64_t basePointerIndex =
5948 combinedInfo.BasePointers.emplace_back(
5949 mapData.BasePointers[basePointerIndex]);
5950 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5952 llvm::Value *size = mapData.Sizes[memberDataIdx];
5954 size = builder.CreateSelect(
5955 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5956 builder.getInt64(0), size);
5959 combinedInfo.Sizes.emplace_back(size);
5964 MapInfosTy &combinedInfo,
5965 TargetDirectiveEnumTy targetDirective,
5966 int mapDataParentIdx = -1) {
5970 auto mapFlag = mapData.Types[mapDataIdx];
5971 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5975 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5977 if (targetDirective == TargetDirectiveEnumTy::Target &&
5978 !mapData.IsDeclareTarget[mapDataIdx])
5979 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5981 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5983 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5988 if (mapDataParentIdx >= 0)
5989 combinedInfo.BasePointers.emplace_back(
5990 mapData.BasePointers[mapDataParentIdx]);
5992 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5994 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5995 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5996 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5997 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5998 combinedInfo.Types.emplace_back(mapFlag);
5999 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
6003 llvm::IRBuilderBase &builder,
6004 llvm::OpenMPIRBuilder &ompBuilder,
6006 MapInfoData &mapData, uint64_t mapDataIndex,
6007 TargetDirectiveEnumTy targetDirective) {
6008 assert(!ompBuilder.Config.isTargetDevice() &&
6009 "function only supported for host device codegen");
6012 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6017 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6018 auto memberClause = llvm::cast<omp::MapInfoOp>(
6019 parentClause.getMembers()[0].getDefiningOp());
6036 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
6038 combinedInfo, mapData, mapDataIndex,
6041 combinedInfo, mapData, mapDataIndex,
6042 memberOfParentFlag, targetDirective);
6052 llvm::IRBuilderBase &builder) {
6054 "function only supported for host device codegen");
6055 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6057 if (!mapData.IsDeclareTarget[i]) {
6058 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6059 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6069 switch (captureKind) {
6070 case omp::VariableCaptureKind::ByRef: {
6071 llvm::Value *newV = mapData.Pointers[i];
6073 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
6076 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6078 if (!offsetIdx.empty())
6079 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6081 mapData.Pointers[i] = newV;
6083 case omp::VariableCaptureKind::ByCopy: {
6084 llvm::Type *type = mapData.BaseType[i];
6086 if (mapData.Pointers[i]->getType()->isPointerTy())
6087 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6089 newV = mapData.Pointers[i];
6092 auto curInsert = builder.saveIP();
6093 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6095 auto *memTempAlloc =
6096 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
6097 builder.SetCurrentDebugLocation(DbgLoc);
6098 builder.restoreIP(curInsert);
6100 builder.CreateStore(newV, memTempAlloc);
6101 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
6104 mapData.Pointers[i] = newV;
6105 mapData.BasePointers[i] = newV;
6107 case omp::VariableCaptureKind::This:
6108 case omp::VariableCaptureKind::VLAType:
6109 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
6120 MapInfoData &mapData,
6121 TargetDirectiveEnumTy targetDirective) {
6123 "function only supported for host device codegen");
6144 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6147 if (mapData.IsAMember[i])
6150 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
6151 if (!mapInfoOp.getMembers().empty()) {
6153 combinedInfo, mapData, i, targetDirective);
6161static llvm::Expected<llvm::Function *>
6163 LLVM::ModuleTranslation &moduleTranslation,
6164 llvm::StringRef mapperFuncName,
6165 TargetDirectiveEnumTy targetDirective);
6167static llvm::Expected<llvm::Function *>
6170 TargetDirectiveEnumTy targetDirective) {
6172 "function only supported for host device codegen");
6173 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6174 std::string mapperFuncName =
6176 {
"omp_mapper", declMapperOp.getSymName()});
6178 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
6186 if (llvm::Function *existingFunc =
6187 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
6188 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
6189 return existingFunc;
6193 mapperFuncName, targetDirective);
6196static llvm::Expected<llvm::Function *>
6199 llvm::StringRef mapperFuncName,
6200 TargetDirectiveEnumTy targetDirective) {
6202 "function only supported for host device codegen");
6203 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6204 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
6207 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
6210 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6213 MapInfosTy combinedInfo;
6215 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6216 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6217 builder.restoreIP(codeGenIP);
6218 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
6219 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
6220 builder.GetInsertBlock());
6221 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
6224 return llvm::make_error<PreviouslyReportedError>();
6225 MapInfoData mapData;
6228 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6234 return combinedInfo;
6238 if (!combinedInfo.Mappers[i])
6241 moduleTranslation, targetDirective);
6245 genMapInfoCB, varType, mapperFuncName, customMapperCB);
6247 return newFn.takeError();
6248 if ([[maybe_unused]] llvm::Function *mappedFunc =
6250 assert(mappedFunc == *newFn &&
6251 "mapper function mapping disagrees with emitted function");
6253 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
6261 llvm::Value *ifCond =
nullptr;
6262 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6266 llvm::omp::RuntimeFunction RTLFn;
6268 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6271 llvm::OpenMPIRBuilder::TargetDataInfo info(
6274 assert(!ompBuilder->Config.isTargetDevice() &&
6275 "target data/enter/exit/update are host ops");
6276 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6278 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
6279 llvm::Value *v = moduleTranslation.
lookupValue(dev);
6280 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
6285 .Case([&](omp::TargetDataOp dataOp) {
6289 if (
auto ifVar = dataOp.getIfExpr())
6293 deviceID = getDeviceID(devId);
6295 mapVars = dataOp.getMapVars();
6296 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6297 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6300 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6304 if (
auto ifVar = enterDataOp.getIfExpr())
6308 deviceID = getDeviceID(devId);
6311 enterDataOp.getNowait()
6312 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6313 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6314 mapVars = enterDataOp.getMapVars();
6315 info.HasNoWait = enterDataOp.getNowait();
6318 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6322 if (
auto ifVar = exitDataOp.getIfExpr())
6326 deviceID = getDeviceID(devId);
6328 RTLFn = exitDataOp.getNowait()
6329 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6330 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6331 mapVars = exitDataOp.getMapVars();
6332 info.HasNoWait = exitDataOp.getNowait();
6335 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6339 if (
auto ifVar = updateDataOp.getIfExpr())
6343 deviceID = getDeviceID(devId);
6346 updateDataOp.getNowait()
6347 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6348 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6349 mapVars = updateDataOp.getMapVars();
6350 info.HasNoWait = updateDataOp.getNowait();
6353 .DefaultUnreachable(
"unexpected operation");
6358 if (!isOffloadEntry)
6359 ifCond = builder.getFalse();
6361 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6362 MapInfoData mapData;
6364 builder, useDevicePtrVars, useDeviceAddrVars);
6367 MapInfosTy combinedInfo;
6368 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6369 builder.restoreIP(codeGenIP);
6370 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6372 return combinedInfo;
6378 [&moduleTranslation](
6379 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6383 for (
auto [arg, useDevVar] :
6384 llvm::zip_equal(blockArgs, useDeviceVars)) {
6386 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6387 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6388 : mapInfoOp.getVarPtr();
6391 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6392 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6393 mapInfoData.MapClause, mapInfoData.DevicePointers,
6394 mapInfoData.BasePointers)) {
6395 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6396 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6397 devicePointer != type)
6400 if (llvm::Value *devPtrInfoMap =
6401 mapper ? mapper(basePointer) : basePointer) {
6402 moduleTranslation.
mapValue(arg, devPtrInfoMap);
6409 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6410 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6411 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6414 builder.restoreIP(codeGenIP);
6415 assert(isa<omp::TargetDataOp>(op) &&
6416 "BodyGen requested for non TargetDataOp");
6417 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6418 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
6419 switch (bodyGenType) {
6420 case BodyGenTy::Priv:
6422 if (!info.DevicePtrInfoMap.empty()) {
6423 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6424 blockArgIface.getUseDeviceAddrBlockArgs(),
6425 useDeviceAddrVars, mapData,
6426 [&](llvm::Value *basePointer) -> llvm::Value * {
6427 if (!info.DevicePtrInfoMap[basePointer].second)
6429 return builder.CreateLoad(
6431 info.DevicePtrInfoMap[basePointer].second);
6433 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6434 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6435 mapData, [&](llvm::Value *basePointer) {
6436 return info.DevicePtrInfoMap[basePointer].second;
6440 moduleTranslation)))
6441 return llvm::make_error<PreviouslyReportedError>();
6444 case BodyGenTy::DupNoPriv:
6445 if (info.DevicePtrInfoMap.empty()) {
6448 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6449 blockArgIface.getUseDeviceAddrBlockArgs(),
6450 useDeviceAddrVars, mapData);
6451 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6452 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6456 case BodyGenTy::NoPriv:
6458 if (info.DevicePtrInfoMap.empty()) {
6460 moduleTranslation)))
6461 return llvm::make_error<PreviouslyReportedError>();
6465 return builder.saveIP();
6468 auto customMapperCB =
6470 if (!combinedInfo.Mappers[i])
6472 info.HasMapper =
true;
6474 moduleTranslation, targetDirective);
6477 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6479 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6481 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6482 if (isa<omp::TargetDataOp>(op))
6483 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6484 deallocBlocks, deviceID, ifCond, info,
6485 genMapInfoCB, customMapperCB,
6488 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6489 deallocBlocks, deviceID, ifCond, info,
6490 genMapInfoCB, customMapperCB, &RTLFn);
6496 builder.restoreIP(*afterIP);
6504 auto distributeOp = cast<omp::DistributeOp>(opInst);
6511 bool doDistributeReduction =
6515 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
6520 if (doDistributeReduction) {
6521 isByRef =
getIsByRef(teamsOp.getReductionByref());
6522 assert(isByRef.size() == teamsOp.getNumReductionVars());
6525 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6529 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6530 .getReductionBlockArgs();
6533 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6534 reductionDecls, privateReductionVariables, reductionVariableMap,
6539 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6541 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
6546 moduleTranslation, allocaIP, deallocBlocks);
6549 builder.restoreIP(codeGenIP);
6553 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
6555 return llvm::make_error<PreviouslyReportedError>();
6560 return llvm::make_error<PreviouslyReportedError>();
6563 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
6565 distributeOp.getPrivateNeedsBarrier())))
6566 return llvm::make_error<PreviouslyReportedError>();
6569 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6572 builder, moduleTranslation);
6574 return regionBlock.takeError();
6575 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
6580 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
6583 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6584 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6585 : omp::ClauseScheduleKind::Static;
6587 bool isOrdered = hasDistSchedule;
6588 std::optional<omp::ScheduleModifier> scheduleMod;
6589 bool isSimd =
false;
6590 llvm::omp::WorksharingLoopType workshareLoopType =
6591 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6592 bool loopNeedsBarrier =
false;
6593 llvm::Value *chunk = moduleTranslation.
lookupValue(
6594 distributeOp.getDistScheduleChunkSize());
6595 llvm::CanonicalLoopInfo *loopInfo =
6597 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6598 ompBuilder->applyWorkshareLoop(
6599 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6600 convertToScheduleKind(schedule), chunk, isSimd,
6601 scheduleMod == omp::ScheduleModifier::monotonic,
6602 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6603 workshareLoopType,
false, hasDistSchedule, chunk);
6606 return wsloopIP.takeError();
6609 distributeOp.getLoc(), privVarsInfo)))
6610 return llvm::make_error<PreviouslyReportedError>();
6612 return llvm::Error::success();
6616 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6618 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6619 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6620 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
6625 builder.restoreIP(*afterIP);
6627 if (doDistributeReduction) {
6630 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6631 privateReductionVariables, isByRef,
6643 if (!cast<mlir::ModuleOp>(op))
6648 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
6649 attribute.getOpenmpDeviceVersion());
6651 if (attribute.getNoGpuLib())
6654 ompBuilder->createGlobalFlag(
6655 attribute.getDebugKind() ,
6656 "__omp_rtl_debug_kind");
6657 ompBuilder->createGlobalFlag(
6659 .getAssumeTeamsOversubscription()
6661 "__omp_rtl_assume_teams_oversubscription");
6662 ompBuilder->createGlobalFlag(
6664 .getAssumeThreadsOversubscription()
6666 "__omp_rtl_assume_threads_oversubscription");
6667 ompBuilder->createGlobalFlag(
6668 attribute.getAssumeNoThreadState() ,
6669 "__omp_rtl_assume_no_thread_state");
6670 ompBuilder->createGlobalFlag(
6672 .getAssumeNoNestedParallelism()
6674 "__omp_rtl_assume_no_nested_parallelism");
6679 omp::TargetOp targetOp,
6680 llvm::OpenMPIRBuilder &ompBuilder,
6681 llvm::vfs::FileSystem &vfs,
6682 llvm::StringRef parentName =
"") {
6683 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
6684 assert(fileLoc &&
"No file found from location");
6686 auto fileInfoCallBack = [&fileLoc]() {
6687 return std::pair<std::string, uint64_t>(
6688 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
6692 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
6698 llvm::IRBuilderBase &builder, llvm::Function *
func) {
6700 "function only supported for target device codegen");
6701 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6702 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6715 if (mapData.IsDeclareTarget[i]) {
6722 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6723 convertUsersOfConstantsToInstructions(constant,
func,
false);
6730 for (llvm::User *user : mapData.OriginalValue[i]->users())
6731 userVec.push_back(user);
6733 for (llvm::User *user : userVec) {
6734 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
6735 if (insn->getFunction() ==
func) {
6736 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6737 llvm::Value *substitute = mapData.BasePointers[i];
6739 : mapOp.getVarPtr())) {
6740 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6741 substitute = builder.CreateLoad(
6742 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6743 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6745 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6793 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
6794 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
6795 llvm::OpenMPIRBuilder &ompBuilder,
6797 llvm::IRBuilderBase::InsertPoint allocaIP,
6798 llvm::IRBuilderBase::InsertPoint codeGenIP,
6800 assert(ompBuilder.Config.isTargetDevice() &&
6801 "function only supported for target device codegen");
6802 builder.restoreIP(allocaIP);
6804 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6806 ompBuilder.M.getContext());
6807 unsigned alignmentValue = 0;
6810 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
6813 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6814 if (mapData.OriginalValue[i] == input) {
6815 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6816 capture = mapOp.getMapCaptureType();
6819 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6823 for (
auto &[val, arg] : blockArgsPairs) {
6824 if (mapOp.getResult() == val) {
6829 assert(mlirArg &&
"expected to find entry block argument for map clause");
6834 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6835 unsigned int defaultAS =
6836 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6839 llvm::Value *v =
nullptr;
6847 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
6848 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
6852 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6853 for (
auto deallocIP : deallocIPs) {
6854 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
6855 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
6859 v = builder.CreateAlloca(arg.getType(), allocaAS);
6861 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6862 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6865 builder.CreateStore(&arg, v);
6867 builder.restoreIP(codeGenIP);
6870 case omp::VariableCaptureKind::ByCopy: {
6874 case omp::VariableCaptureKind::ByRef: {
6875 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6877 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6892 if (v->getType()->isPointerTy() && alignmentValue) {
6893 llvm::MDBuilder MDB(builder.getContext());
6894 loadInst->setMetadata(
6895 llvm::LLVMContext::MD_align,
6896 llvm::MDNode::get(builder.getContext(),
6897 MDB.createConstant(llvm::ConstantInt::get(
6898 llvm::Type::getInt64Ty(builder.getContext()),
6905 case omp::VariableCaptureKind::This:
6906 case omp::VariableCaptureKind::VLAType:
6909 assert(
false &&
"Currently unsupported capture kind");
6913 return builder.saveIP();
6930 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6931 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6932 blockArgIface.getHostEvalBlockArgs())) {
6933 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6937 .Case([&](omp::TeamsOp teamsOp) {
6938 if (teamsOp.getNumTeamsLower() == blockArg)
6939 numTeamsLower = hostEvalVar;
6940 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6942 numTeamsUpper = hostEvalVar;
6943 else if (!teamsOp.getThreadLimitVars().empty() &&
6944 teamsOp.getThreadLimit(0) == blockArg)
6945 threadLimit = hostEvalVar;
6947 llvm_unreachable(
"unsupported host_eval use");
6949 .Case([&](omp::ParallelOp parallelOp) {
6950 if (!parallelOp.getNumThreadsVars().empty() &&
6951 parallelOp.getNumThreads(0) == blockArg)
6952 numThreads = hostEvalVar;
6954 llvm_unreachable(
"unsupported host_eval use");
6956 .Case([&](omp::LoopNestOp loopOp) {
6957 auto processBounds =
6961 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6962 if (lb == blockArg) {
6965 (*outBounds)[i] = hostEvalVar;
6971 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6972 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6974 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6976 assert(found &&
"unsupported host_eval use");
6978 .DefaultUnreachable(
"unsupported host_eval use");
6990template <
typename OpTy>
6995 if (OpTy casted = dyn_cast<OpTy>(op))
6998 if (immediateParent)
6999 return dyn_cast_if_present<OpTy>(op->
getParentOp());
7008 return std::nullopt;
7011 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7012 return constAttr.getInt();
7014 return std::nullopt;
7019 uint64_t sizeInBytes = sizeInBits / 8;
7023template <
typename OpTy>
7025 if (op.getNumReductionVars() > 0) {
7030 members.reserve(reductions.size());
7031 for (omp::DeclareReductionOp &red : reductions) {
7035 if (red.getByrefElementType())
7036 members.push_back(*red.getByrefElementType());
7038 members.push_back(red.getType());
7041 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7057 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7058 bool isTargetDevice,
bool isGPU) {
7061 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
7062 if (!isTargetDevice) {
7070 numTeamsLower = teamsOp.getNumTeamsLower();
7072 if (!teamsOp.getNumTeamsUpperVars().empty())
7073 numTeamsUpper = teamsOp.getNumTeams(0);
7074 if (!teamsOp.getThreadLimitVars().empty())
7075 threadLimit = teamsOp.getThreadLimit(0);
7079 if (!parallelOp.getNumThreadsVars().empty())
7080 numThreads = parallelOp.getNumThreads(0);
7086 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7090 if (numTeamsUpper) {
7092 minTeamsVal = maxTeamsVal = *val;
7094 minTeamsVal = maxTeamsVal = 0;
7100 minTeamsVal = maxTeamsVal = 1;
7102 minTeamsVal = maxTeamsVal = -1;
7107 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
7121 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7122 if (!targetOp.getThreadLimitVars().empty())
7123 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7124 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7127 int32_t maxThreadsVal = -1;
7129 setMaxValueFromClause(numThreads, maxThreadsVal);
7137 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7138 if (combinedMaxThreadsVal < 0 ||
7139 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7140 combinedMaxThreadsVal = teamsThreadLimitVal;
7142 if (combinedMaxThreadsVal < 0 ||
7143 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7144 combinedMaxThreadsVal = maxThreadsVal;
7146 int32_t reductionDataSize = 0;
7147 if (isGPU && capturedOp) {
7153 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7155 case omp::TargetExecMode::bare:
7156 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7158 case omp::TargetExecMode::generic:
7159 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7161 case omp::TargetExecMode::spmd:
7162 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7164 case omp::TargetExecMode::no_loop:
7165 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
7168 attrs.MinTeams = minTeamsVal;
7169 attrs.MaxTeams.front() = maxTeamsVal;
7170 attrs.MinThreads = 1;
7171 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7172 attrs.ReductionDataSize = reductionDataSize;
7175 if (attrs.ReductionDataSize != 0)
7176 attrs.ReductionBufferLength = 1024;
7188 omp::TargetOp targetOp,
Operation *capturedOp,
7189 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
7191 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7193 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7197 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
7200 if (!targetOp.getThreadLimitVars().empty()) {
7201 Value targetThreadLimit = targetOp.getThreadLimit(0);
7202 attrs.TargetThreadLimit.front() =
7210 attrs.MinTeams = builder.CreateSExtOrTrunc(
7211 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
7214 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7215 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
7217 if (teamsThreadLimit)
7218 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7219 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
7222 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
7224 bool hostEvalTripCount;
7225 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7226 if (hostEvalTripCount) {
7228 attrs.LoopTripCount =
nullptr;
7233 for (
auto [loopLower, loopUpper, loopStep] :
7234 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7235 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
7236 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
7237 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
7239 if (!lowerBound || !upperBound || !step) {
7240 attrs.LoopTripCount =
nullptr;
7244 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7245 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7246 loc, lowerBound, upperBound, step,
true,
7247 loopOp.getLoopInclusive());
7249 if (!attrs.LoopTripCount) {
7250 attrs.LoopTripCount = tripCount;
7255 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
7260 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7262 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
7264 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
7268static llvm::omp::OMPDynGroupprivateFallbackType
7270 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7271 : omp::FallbackModifier::default_mem;
7273 case omp::FallbackModifier::abort:
7274 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7275 case omp::FallbackModifier::null:
7276 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7277 case omp::FallbackModifier::default_mem:
7278 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
7281 llvm_unreachable(
"unexpected dyn_groupprivate fallback type");
7287 auto targetOp = cast<omp::TargetOp>(opInst);
7292 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7301 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7302 assert(parentBB &&
"No insert block is set for the builder");
7303 llvm::Function *parentLLVMFn = parentBB->getParent();
7304 assert(parentLLVMFn &&
"Parent Function must be valid");
7305 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7306 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7307 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7308 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7311 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7312 bool isGPU = ompBuilder->Config.isGPU();
7315 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7316 auto &targetRegion = targetOp.getRegion();
7333 llvm::Function *llvmOutlinedFn =
nullptr;
7334 TargetDirectiveEnumTy targetDirective =
7335 getTargetDirectiveEnumTyFromOp(&opInst);
7339 bool isOffloadEntry =
7340 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
7347 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7349 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7350 std::optional<DenseI64ArrayAttr> privateMapIndices =
7351 targetOp.getPrivateMapsAttr();
7353 for (
auto [privVarIdx, privVarSymPair] :
7354 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7355 auto privVar = std::get<0>(privVarSymPair);
7356 auto privSym = std::get<1>(privVarSymPair);
7358 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7359 omp::PrivateClauseOp privatizer =
7362 if (!privatizer.needsMap())
7366 targetOp.getMappedValueForPrivateVar(privVarIdx);
7367 assert(mappedValue &&
"Expected to find mapped value for a privatized "
7368 "variable that needs mapping");
7373 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
7374 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
7378 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7380 varType == privVar.getType() &&
7381 "Type of private var doesn't match the type of the mapped value");
7385 mappedPrivateVars.insert(
7387 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7388 (*privateMapIndices)[privVarIdx])});
7392 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
7393 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7395 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7396 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7397 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7400 llvm::Function *llvmParentFn =
7402 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7403 assert(llvmParentFn && llvmOutlinedFn &&
7404 "Both parent and outlined functions must exist at this point");
7406 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7407 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
7409 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
7410 attr.isStringAttribute())
7411 llvmOutlinedFn->addFnAttr(attr);
7413 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
7414 attr.isStringAttribute())
7415 llvmOutlinedFn->addFnAttr(attr);
7417 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7418 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7419 llvm::Value *mapOpValue =
7420 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7421 moduleTranslation.
mapValue(arg, mapOpValue);
7423 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7424 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7425 llvm::Value *mapOpValue =
7426 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
7427 moduleTranslation.
mapValue(arg, mapOpValue);
7436 privateVarsInfo, allocaIP, &mappedPrivateVars);
7439 return llvm::make_error<PreviouslyReportedError>();
7441 builder.restoreIP(codeGenIP);
7443 &mappedPrivateVars),
7446 return llvm::make_error<PreviouslyReportedError>();
7449 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7451 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7452 return llvm::make_error<PreviouslyReportedError>();
7455 moduleTranslation, allocaIP, deallocBlocks);
7457 targetRegion,
"omp.target", builder, moduleTranslation);
7460 return llvm::make_error<PreviouslyReportedError>();
7462 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7465 targetOp.getLoc(), privateVarsInfo)))
7466 return llvm::make_error<PreviouslyReportedError>();
7468 return builder.saveIP();
7471 StringRef parentName = parentFn.getName();
7473 llvm::TargetRegionEntryInfo entryInfo;
7479 MapInfoData mapData;
7484 MapInfosTy combinedInfos;
7486 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7487 builder.restoreIP(codeGenIP);
7488 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7493 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7494 combinedInfos.BasePointers.push_back(nullPtr);
7495 combinedInfos.Pointers.push_back(nullPtr);
7496 combinedInfos.DevicePointers.push_back(
7497 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7498 combinedInfos.Sizes.push_back(builder.getInt64(0));
7499 combinedInfos.Types.push_back(
7500 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7501 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7502 if (!combinedInfos.Names.empty())
7503 combinedInfos.Names.push_back(nullPtr);
7504 combinedInfos.Mappers.push_back(
nullptr);
7506 return combinedInfos;
7509 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7510 llvm::Value *&retVal, InsertPointTy allocaIP,
7511 InsertPointTy codeGenIP,
7513 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7514 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7515 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7521 if (!isTargetDevice) {
7522 retVal = cast<llvm::Value>(&arg);
7527 builder, *ompBuilder, moduleTranslation,
7528 allocaIP, codeGenIP, deallocIPs);
7531 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7532 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7533 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7535 isTargetDevice, isGPU);
7539 if (!isTargetDevice)
7541 targetCapturedOp, runtimeAttrs);
7549 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7550 llvm::Value *value = moduleTranslation.
lookupValue(var);
7551 moduleTranslation.
mapValue(arg, value);
7553 if (!llvm::isa<llvm::Constant>(value))
7554 kernelInput.push_back(value);
7557 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7564 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
7565 kernelInput.push_back(mapData.OriginalValue[i]);
7569 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7572 llvm::OpenMPIRBuilder::DependenciesInfo dds;
7574 targetOp.getDependVars(), targetOp.getDependKinds(),
7575 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
7576 builder, moduleTranslation, dds)))
7579 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7581 llvm::OpenMPIRBuilder::TargetDataInfo info(
7585 auto customMapperCB =
7587 if (!combinedInfos.Mappers[i])
7589 info.HasMapper =
true;
7591 moduleTranslation, targetDirective);
7594 llvm::Value *ifCond =
nullptr;
7595 if (
Value targetIfCond = targetOp.getIfExpr())
7596 ifCond = moduleTranslation.
lookupValue(targetIfCond);
7598 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
7599 llvm::Value *dynSizeVal =
nullptr;
7600 if (dynGroupPrivateSize) {
7601 dynSizeVal = moduleTranslation.
lookupValue(dynGroupPrivateSize);
7602 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
7606 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
7609 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7611 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
7612 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
7613 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
7614 targetOp.getNowait(), dynSizeVal, fallbackType);
7619 builder.restoreIP(*afterIP);
7622 builder.CreateFree(dds.DepArray);
7635 llvm::OpenMPIRBuilder *ompBuilder,
7644 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7645 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7647 if (!offloadMod.getIsTargetDevice())
7650 omp::DeclareTargetDeviceType declareType =
7651 attribute.getDeviceType().getValue();
7653 if (declareType == omp::DeclareTargetDeviceType::host) {
7654 llvm::Function *llvmFunc =
7656 llvmFunc->dropAllReferences();
7657 llvmFunc->eraseFromParent();
7661 ompBuilder->Builder.ClearInsertionPoint();
7662 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
7668 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7669 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7670 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7672 bool isDeclaration = gOp.isDeclaration();
7673 bool isExternallyVisible =
7676 llvm::StringRef mangledName = gOp.getSymName();
7677 auto captureClause =
7683 std::vector<llvm::GlobalVariable *> generatedRefs;
7685 std::vector<llvm::Triple> targetTriple;
7686 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7688 LLVM::LLVMDialect::getTargetTripleAttrName()));
7689 if (targetTripleAttr)
7690 targetTriple.emplace_back(targetTripleAttr.data());
7692 auto fileInfoCallBack = [&loc]() {
7693 std::string filename =
"";
7694 std::uint64_t lineNo = 0;
7697 filename = loc.getFilename().str();
7698 lineNo = loc.getLine();
7701 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7705 llvm::vfs::FileSystem &vfs = moduleTranslation.
getFileSystem();
7707 ompBuilder->registerTargetGlobalVariable(
7708 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7709 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7710 mangledName, generatedRefs,
false, targetTriple,
7712 gVal->getType(), gVal);
7714 if (ompBuilder->Config.isTargetDevice() &&
7715 (attribute.getCaptureClause().getValue() !=
7716 mlir::omp::DeclareTargetCaptureClause::to ||
7717 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7718 ompBuilder->getAddrOfDeclareTargetVar(
7719 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7720 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7721 mangledName, generatedRefs,
false, targetTriple,
7722 gVal->getType(),
nullptr,
7735class OpenMPDialectLLVMIRTranslationInterface
7736 :
public LLVMTranslationDialectInterface {
7738 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
7743 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7744 LLVM::ModuleTranslation &moduleTranslation)
const final;
7749 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7750 NamedAttribute attribute,
7751 LLVM::ModuleTranslation &moduleTranslation)
const final;
7756 void registerAllocatedPtr(Value var, llvm::Value *ptr)
const {
7757 ompAllocatedPtrs[var] = ptr;
7762 llvm::Value *lookupAllocatedPtr(Value var)
const {
7763 auto it = ompAllocatedPtrs.find(var);
7764 return it != ompAllocatedPtrs.end() ? it->second :
nullptr;
7776LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7777 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7778 NamedAttribute attribute,
7779 LLVM::ModuleTranslation &moduleTranslation)
const {
7780 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7782 .Case(
"omp.is_target_device",
7783 [&](Attribute attr) {
7784 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7785 llvm::OpenMPIRBuilderConfig &config =
7787 config.setIsTargetDevice(deviceAttr.getValue());
7793 [&](Attribute attr) {
7794 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7795 llvm::OpenMPIRBuilderConfig &config =
7797 config.setIsGPU(gpuAttr.getValue());
7802 .Case(
"omp.host_ir_filepath",
7803 [&](Attribute attr) {
7804 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7805 llvm::OpenMPIRBuilder *ompBuilder =
7807 ompBuilder->loadOffloadInfoMetadata(
7808 moduleTranslation.
getFileSystem(), filepathAttr.getValue());
7814 [&](Attribute attr) {
7815 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7819 .Case(
"omp.version",
7820 [&](Attribute attr) {
7821 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7822 llvm::OpenMPIRBuilder *ompBuilder =
7824 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
7825 versionAttr.getVersion());
7830 .Case(
"omp.declare_target",
7831 [&](Attribute attr) {
7832 if (
auto declareTargetAttr =
7833 dyn_cast<omp::DeclareTargetAttr>(attr)) {
7834 llvm::OpenMPIRBuilder *ompBuilder =
7837 ompBuilder, moduleTranslation);
7841 .Case(
"omp.requires",
7842 [&](Attribute attr) {
7843 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7844 using Requires = omp::ClauseRequires;
7845 Requires flags = requiresAttr.getValue();
7846 llvm::OpenMPIRBuilderConfig &config =
7848 config.setHasRequiresReverseOffload(
7849 bitEnumContainsAll(flags, Requires::reverse_offload));
7850 config.setHasRequiresUnifiedAddress(
7851 bitEnumContainsAll(flags, Requires::unified_address));
7852 config.setHasRequiresUnifiedSharedMemory(
7853 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7854 config.setHasRequiresDynamicAllocators(
7855 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7860 .Case(
"omp.target_triples",
7861 [&](Attribute attr) {
7862 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7863 llvm::OpenMPIRBuilderConfig &config =
7865 config.TargetTriples.clear();
7866 config.TargetTriples.reserve(triplesAttr.size());
7867 for (Attribute tripleAttr : triplesAttr) {
7868 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7869 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7877 .Default([](Attribute) {
7893 if (
auto declareTargetIface =
7894 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7895 parentFn.getOperation()))
7896 if (declareTargetIface.isDeclareTarget() &&
7897 declareTargetIface.getDeclareTargetDeviceType() !=
7898 mlir::omp::DeclareTargetDeviceType::host)
7908 llvm::Module *llvmModule) {
7909 llvm::Type *i64Ty = builder.getInt64Ty();
7910 llvm::Type *i32Ty = builder.getInt32Ty();
7911 llvm::Type *returnType = builder.getPtrTy(0);
7912 llvm::FunctionType *fnType =
7913 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
7914 llvm::Function *
func = cast<llvm::Function>(
7915 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
7919template <
typename T>
7923 llvm::DataLayout dataLayout =
7925 llvm::Type *llvmHeapTy =
7926 moduleTranslation.
convertType(op.getMemElemTypeAttr().getValue());
7928 auto alignment = op.getMemAlignment();
7929 llvm::TypeSize typeSize = llvm::alignTo(
7930 dataLayout.getTypeStoreSize(llvmHeapTy),
7931 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
7933 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7934 return builder.CreateMul(
7936 builder.CreateIntCast(moduleTranslation.
lookupValue(op.getMemArraySize()),
7937 builder.getInt64Ty(),
7944 omp::TargetAllocMemOp op) {
7945 llvm::DataLayout dataLayout =
7947 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(op.getAllocatedType());
7948 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
7949 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7950 for (
auto typeParam : op.getTypeparams()) {
7951 allocSize = builder.CreateMul(
7953 builder.CreateIntCast(moduleTranslation.
lookupValue(typeParam),
7954 builder.getInt64Ty(),
7963 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7968 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7972 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7974 llvm::Value *allocSize =
7977 llvm::CallInst *call =
7978 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7979 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7982 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
7988 llvm::IRBuilderBase &builder,
7992 moduleTranslation.
mapValue(allocMemOp.getResult(),
7993 ompBuilder->createOMPAllocShared(builder, size));
8000 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8001 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8004 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8005 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8006 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8008 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
8010 llvm::Value *allocator;
8011 if (
auto allocatorVar = allocateDirOp.getAllocator()) {
8012 allocator = moduleTranslation.
lookupValue(allocatorVar);
8013 if (allocator->getType()->isIntegerTy())
8014 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8015 else if (allocator->getType()->isPointerTy())
8016 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8017 allocator, builder.getPtrTy());
8019 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8022 for (
Value var : vars) {
8023 llvm::Type *llvmVarTy = moduleTranslation.
convertType(var.getType());
8027 llvm::Type *typeToInspect = llvmVarTy;
8028 if (llvmVarTy->isPointerTy()) {
8031 if (
auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8032 typeToInspect = moduleTranslation.
convertType(gop.getGlobalType());
8037 if (
auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8038 llvm::Value *elementCount = builder.getInt64(1);
8039 llvm::Type *currentType = arrTy;
8040 while (
auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8041 elementCount = builder.CreateMul(
8042 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8043 currentType = nestedArrTy->getElementType();
8045 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8047 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
8049 size = builder.getInt64(
8050 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
8053 uint64_t alignValue =
8054 alignAttr ? alignAttr.value()
8055 : dataLayout.getABITypeAlign(typeToInspect).value();
8056 llvm::Value *alignConst = builder.getInt64(alignValue);
8058 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1),
"",
true);
8059 size = builder.CreateUDiv(size, alignConst);
8060 size = builder.CreateMul(size, alignConst,
"",
true);
8062 std::string allocName =
8063 ompBuilder->createPlatformSpecificName({
".void.addr"});
8064 llvm::CallInst *allocCall;
8065 if (alignAttr.has_value()) {
8066 allocCall = ompBuilder->createOMPAlignedAlloc(
8067 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8071 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
8074 ompIface.registerAllocatedPtr(var, allocCall);
8083 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8084 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8086 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8088 llvm::Value *allocator;
8089 if (
auto allocatorVar = freeOp.getAllocator()) {
8090 allocator = moduleTranslation.
lookupValue(allocatorVar);
8091 if (allocator->getType()->isIntegerTy())
8092 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8093 else if (allocator->getType()->isPointerTy())
8094 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8095 allocator, builder.getPtrTy());
8097 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8102 for (
Value var : llvm::reverse(vars)) {
8103 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
8105 return opInst.
emitError(
"omp.allocate_free: no allocation recorded");
8106 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator,
"");
8113 llvm::Module *llvmModule) {
8114 llvm::Type *ptrTy = builder.getPtrTy(0);
8115 llvm::Type *i32Ty = builder.getInt32Ty();
8116 llvm::Type *voidTy = builder.getVoidTy();
8117 llvm::FunctionType *fnType =
8118 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
8119 llvm::Function *
func = dyn_cast<llvm::Function>(
8120 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
8127 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8132 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8136 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
8139 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
8141 llvm::Value *intToPtr =
8142 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8143 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
8149 llvm::IRBuilderBase &builder,
8153 ompBuilder->createOMPFreeShared(
8154 builder, moduleTranslation.
lookupValue(freeMemOp.getHeapref()), size);
8163 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8168 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
8172 bool shouldAllocate =
true;
8173 switch (groupprivateOp.getDeviceType().value_or(
8174 mlir::omp::DeclareTargetDeviceType::any)) {
8175 case mlir::omp::DeclareTargetDeviceType::host:
8176 shouldAllocate = !isTargetDevice;
8178 case mlir::omp::DeclareTargetDeviceType::nohost:
8179 shouldAllocate = isTargetDevice;
8181 case mlir::omp::DeclareTargetDeviceType::any:
8182 shouldAllocate =
true;
8188 &opInst, groupprivateOp.getSymNameAttr());
8191 <<
"expected symbol '" << groupprivateOp.getSymName()
8192 <<
"' to reference an LLVM global variable";
8194 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
8195 llvm::Type *varType = moduleTranslation.
convertType(global.getType());
8196 std::string varName = globalValue->getName().str();
8198 llvm::Value *resultPtr;
8199 if (shouldAllocate && isTargetDevice) {
8200 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
8201 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8202 unsigned sharedAddressSpace;
8203 if (targetTriple.isAMDGCN())
8204 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8205 else if (targetTriple.isNVPTX())
8206 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8208 return opInst.
emitError() <<
"groupprivate is not supported for target: "
8209 << targetTriple.str();
8210 llvm::GlobalVariable *sharedVar =
new llvm::GlobalVariable(
8211 *llvmModule, varType,
false,
8212 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8213 varName,
nullptr, llvm::GlobalValue::NotThreadLocal,
8216 resultPtr = sharedVar;
8218 if (shouldAllocate && !isTargetDevice)
8219 opInst.
emitWarning(
"groupprivate directive is currently ignored on the "
8220 "host, using original global");
8221 resultPtr = globalValue;
8230LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8231 Operation *op, llvm::IRBuilderBase &builder,
8232 LLVM::ModuleTranslation &moduleTranslation)
const {
8235 if (ompBuilder->Config.isTargetDevice() &&
8236 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8239 return op->
emitOpError() <<
"unsupported host op found in device";
8247 bool isOutermostLoopWrapper =
8248 isa_and_present<omp::LoopWrapperInterface>(op) &&
8249 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
8258 if (isa<omp::TaskloopContextOp>(op))
8259 isOutermostLoopWrapper =
true;
8260 else if (isa<omp::TaskloopWrapperOp>(op))
8261 isOutermostLoopWrapper =
false;
8263 if (isOutermostLoopWrapper)
8264 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
8267 llvm::TypeSwitch<Operation *, LogicalResult>(op)
8268 .Case([&](omp::BarrierOp op) -> LogicalResult {
8272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8273 ompBuilder->createBarrier(builder.saveIP(),
8274 llvm::omp::OMPD_barrier);
8276 if (res.succeeded()) {
8279 builder.restoreIP(*afterIP);
8283 .Case([&](omp::TaskyieldOp op) {
8287 ompBuilder->createTaskyield(builder.saveIP());
8290 .Case([&](omp::FlushOp op) {
8302 ompBuilder->createFlush(builder.saveIP());
8305 .Case([&](omp::ParallelOp op) {
8308 .Case([&](omp::MaskedOp) {
8311 .Case([&](omp::MasterOp) {
8314 .Case([&](omp::CriticalOp) {
8317 .Case([&](omp::OrderedRegionOp) {
8320 .Case([&](omp::OrderedOp) {
8323 .Case([&](omp::WsloopOp) {
8326 .Case([&](omp::SimdOp) {
8329 .Case([&](omp::AtomicReadOp) {
8332 .Case([&](omp::AtomicWriteOp) {
8335 .Case([&](omp::AtomicUpdateOp op) {
8338 .Case([&](omp::AtomicCaptureOp op) {
8341 .Case([&](omp::CancelOp op) {
8344 .Case([&](omp::CancellationPointOp op) {
8347 .Case([&](omp::SectionsOp) {
8350 .Case([&](omp::ScopeOp op) {
8353 .Case([&](omp::SingleOp op) {
8356 .Case([&](omp::TeamsOp op) {
8359 .Case([&](omp::TaskOp op) {
8362 .Case([&](omp::TaskloopWrapperOp op) {
8365 .Case([&](omp::TaskloopContextOp op) {
8368 .Case([&](omp::TaskgroupOp op) {
8371 .Case([&](omp::TaskwaitOp op) {
8374 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8375 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8376 omp::CriticalDeclareOp>([](
auto op) {
8389 .Case([&](omp::ThreadprivateOp) {
8392 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8393 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
8396 .Case([&](omp::TargetOp) {
8399 .Case([&](omp::DistributeOp) {
8402 .Case([&](omp::LoopNestOp) {
8405 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8406 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
8412 .Case([&](omp::NewCliOp op) {
8417 .Case([&](omp::CanonicalLoopOp op) {
8420 .Case([&](omp::UnrollHeuristicOp op) {
8429 .Case([&](omp::TileOp op) {
8430 return applyTile(op, builder, moduleTranslation);
8432 .Case([&](omp::FuseOp op) {
8433 return applyFuse(op, builder, moduleTranslation);
8435 .Case([&](omp::TargetAllocMemOp) {
8438 .Case([&](omp::TargetFreeMemOp) {
8441 .Case([&](omp::AllocateDirOp) {
8444 .Case([&](omp::AllocateFreeOp) {
8448 .Case([&](omp::AllocSharedMemOp op) {
8451 .Case([&](omp::FreeSharedMemOp op) {
8454 .Case([&](omp::GroupprivateOp) {
8457 .Default([&](Operation *inst) {
8459 <<
"not yet implemented: " << inst->
getName();
8462 if (isOutermostLoopWrapper)
8469 registry.
insert<omp::OpenMPDialect>();
8471 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
Handles the cleanup region of a DeclareReductionOp.
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars