24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
67 llvm_unreachable(
"unhandled schedule clause argument");
// ModuleTranslation stack frame for OpenMP operations: remembers the
// insertion point at which allocas for the region currently being
// translated should be emitted.
// NOTE(review): the base-class declaration, type-ID macro, and access
// specifiers for this class fall on lines missing from this chunk —
// confirm against the full file.
72class OpenMPAllocaStackFrame
// Capture the alloca insertion point supplied by the OpenMPIRBuilder.
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// Stack frame holding the llvm::CanonicalLoopInfo for the loop currently
// being translated, so nested translation steps can find it (see the
// stackWalk over OpenMPLoopInfoStackFrame later in this file).
// NOTE(review): base class / access specifiers are on lines missing from
// this chunk — confirm against the full file.
85class OpenMPLoopInfoStackFrame
// Null until a canonical loop has been created for the current frame.
89 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
// Custom error type signalling that an error has already been reported to
// the user, so callers can propagate failure without emitting a duplicate
// diagnostic (used via llvm::handleAllErrors later in this file).
108class PreviouslyReportedError
109 :
public llvm::ErrorInfo<PreviouslyReportedError> {
// Intentionally prints nothing: the diagnostic was already emitted.
// NOTE(review): the body of log() is on lines missing from this chunk.
111 void log(raw_ostream &)
const override {
// NOTE(review): the statement preceding this string (presumably a
// report_fatal_error / llvm_unreachable call) is on a missing line.
115 std::error_code convertToErrorCode()
const override {
117 "PreviouslyReportedError doesn't support ECError conversion");
// Static ID used by llvm::ErrorInfo's RTTI-like classof mechanism.
124char PreviouslyReportedError::ID = 0;
// Helper that implements the OpenMP `linear` clause during translation:
// it allocates per-variable storage, computes start + iv * step in the
// loop body, and writes the final value back to the original variable
// after the loop (on the last iteration, plus a barrier).
// NOTE(review): many closing braces and some interior lines of this class
// fall on lines missing from this chunk; the code below is kept
// byte-identical to the damaged source and only comments are added.
135class LinearClauseProcessor {
// Allocas holding each linear variable's value captured before the loop.
138 SmallVector<llvm::Value *> linearPreconditionVars;
// Allocas holding each linear variable's per-iteration computed value.
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
// The original (caller-visible) storage of each linear variable.
140 SmallVector<llvm::Value *> linearOrigVal;
// Step value for each linear variable.
141 SmallVector<llvm::Value *> linearSteps;
// Converted LLVM type of each linear variable.
142 SmallVector<llvm::Type *> linearVarTypes;
// Blocks created by splitLinearFiniBB for the finalization sequence.
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
// Record the converted LLVM type of a linear variable given its MLIR
// TypeAttr.
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.
convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
// Create the two allocas (pre-loop value and per-iteration result) for
// linear variable `idx` and remember its original storage.
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar,
int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx],
nullptr,
".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
// Record the translated LLVM value of a linear variable's step.
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.
lookupValue(linearStep));
// In the loop pre-header, snapshot each original variable's current value
// into its precondition alloca.
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
// In the loop body, compute start + iv * step for every linear variable
// and store the result into its loop-body temp alloca. Integer variables
// use integer arithmetic; floating-point variables compute the product in
// the induction variable's integer type and convert it with SIToFP.
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (
size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
// The induction variable is required to be an integer.
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable(
"OpenMP loop induction variable must be an integer "
198 if (linearVarType->isIntegerTy()) {
// Match the widths of iv and step to the linear variable's type.
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 }
else if (linearVarType->isFloatingPointTy()) {
// Compute iv * step as an integer, then convert to floating point.
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
// NOTE(review): the llvm_unreachable call preceding this string is on a
// line missing from this chunk.
220 "Linear variable must be of integer or floating-point type");
// Split the loop-exit block into finalization, exit, and
// last-iteration-exit blocks used by finalizeLinearVar below.
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(),
"omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(),
"omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(),
"omp_loop.linear_lastiter_exit");
// Emit the "was this the last iteration?" check; if so, copy the computed
// values back into the original variables, then branch to the exit block
// and emit a barrier there.
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
// lastIter is an i32 flag: nonzero means the last iteration executed.
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Replace the unconditional terminator with the last-iteration branch.
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
265 builder.SetInsertPoint(linearExitBB->getTerminator());
// NOTE(review): the barrier-creation call taking these arguments is on a
// line missing from this chunk — confirm against the full file.
267 builder.saveIP(), llvm::omp::OMPD_barrier);
// Unconditionally store the computed linear values back to the original
// variables at the current insertion point.
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (
size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
// Walk the CFG from startBB and, inside every reachable block whose name
// starts with `prefix`, rewrite uses of the original linear variable to
// use its loop-body temp instead.
// NOTE(review): the parameter declaring `varIndex` (and possibly others)
// is on a line missing from this chunk — confirm against the full file.
282 void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
283 llvm::BasicBlock *endBB, llvm::StringRef prefix,
285 llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
286 llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
287 llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;
289 assert(startBB && endBB &&
"Invalid startBB/endBB");
// Depth-first traversal collecting blocks whose names match the prefix.
293 worklist.push_back(startBB);
294 visited.insert(startBB);
296 while (!worklist.empty()) {
297 llvm::BasicBlock *bb = worklist.pop_back_val();
299 if (bb->hasName() && bb->getName().starts_with(prefix))
300 matchingBBs.insert(bb);
305 for (llvm::BasicBlock *succ : llvm::successors(bb)) {
306 if (visited.insert(succ).second)
307 worklist.push_back(succ);
// Redirect uses of the original variable inside the matching blocks.
312 for (
auto *user : linearOrigVal[varIndex]->users()) {
313 if (
auto *userInst = dyn_cast<llvm::Instruction>(user)) {
314 if (matchingBBs.contains(userInst->getParent()))
315 user->replaceUsesOfWith(linearOrigVal[varIndex],
316 linearLoopBodyTemps[varIndex]);
327 SymbolRefAttr symbolName) {
328 omp::PrivateClauseOp privatizer =
331 assert(privatizer &&
"privatizer not found in the symbol table");
342 auto todo = [&op](StringRef clauseName) {
343 return op.
emitError() <<
"not yet implemented: Unhandled clause "
344 << clauseName <<
" in " << op.
getName()
348 auto checkAllocate = [&todo](
auto op, LogicalResult &
result) {
349 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
350 result = todo(
"allocate");
352 auto checkBare = [&todo](
auto op, LogicalResult &
result) {
354 result = todo(
"ompx_bare");
356 auto checkDepend = [&todo](
auto op, LogicalResult &
result) {
357 if (!op.getDependVars().empty() || op.getDependKinds())
360 auto checkDependIteratorModifier = [&todo](
auto op, LogicalResult &
result) {
361 if (!op.getDependIterated().empty() ||
362 (op.getDependIteratedKinds() && !op.getDependIteratedKinds()->empty()))
363 result = todo(
"depend with iterator modifier");
365 auto checkHint = [](
auto op, LogicalResult &) {
369 auto checkInReduction = [&todo](
auto op, LogicalResult &
result) {
370 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
371 op.getInReductionSyms())
372 result = todo(
"in_reduction");
374 auto checkNowait = [&todo](
auto op, LogicalResult &
result) {
378 auto checkOrder = [&todo](
auto op, LogicalResult &
result) {
379 if (op.getOrder() || op.getOrderMod())
382 auto checkPrivate = [&todo](
auto op, LogicalResult &
result) {
383 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
384 result = todo(
"privatization");
386 auto checkReduction = [&todo](
auto op, LogicalResult &
result) {
387 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
388 if (!op.getReductionVars().empty() || op.getReductionByref() ||
389 op.getReductionSyms())
390 result = todo(
"reduction");
391 if (op.getReductionMod() &&
392 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
393 result = todo(
"reduction with modifier");
395 auto checkTaskReduction = [&todo](
auto op, LogicalResult &
result) {
396 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
397 op.getTaskReductionSyms())
398 result = todo(
"task_reduction");
400 auto checkNumTeams = [&todo](
auto op, LogicalResult &
result) {
401 if (op.hasNumTeamsMultiDim())
402 result = todo(
"num_teams with multi-dimensional values");
404 auto checkNumThreads = [&todo](
auto op, LogicalResult &
result) {
405 if (op.hasNumThreadsMultiDim())
406 result = todo(
"num_threads with multi-dimensional values");
409 auto checkThreadLimit = [&todo](
auto op, LogicalResult &
result) {
410 if (op.hasThreadLimitMultiDim())
411 result = todo(
"thread_limit with multi-dimensional values");
416 .Case([&](omp::DistributeOp op) {
417 checkAllocate(op,
result);
420 .Case([&](omp::SectionsOp op) {
421 checkAllocate(op,
result);
423 checkReduction(op,
result);
425 .Case([&](omp::SingleOp op) {
426 checkAllocate(op,
result);
429 .Case([&](omp::TeamsOp op) {
430 checkAllocate(op,
result);
432 checkNumTeams(op,
result);
433 checkThreadLimit(op,
result);
435 .Case([&](omp::TaskOp op) {
436 checkAllocate(op,
result);
437 checkDependIteratorModifier(op,
result);
438 checkInReduction(op,
result);
440 .Case([&](omp::TaskgroupOp op) {
441 checkAllocate(op,
result);
442 checkTaskReduction(op,
result);
444 .Case([&](omp::TaskwaitOp op) {
448 .Case([&](omp::TaskloopOp op) {
449 checkAllocate(op,
result);
450 checkInReduction(op,
result);
451 checkReduction(op,
result);
453 .Case([&](omp::WsloopOp op) {
454 checkAllocate(op,
result);
456 checkReduction(op,
result);
458 .Case([&](omp::ParallelOp op) {
459 checkAllocate(op,
result);
460 checkReduction(op,
result);
461 checkNumThreads(op,
result);
463 .Case([&](omp::SimdOp op) { checkReduction(op,
result); })
464 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
465 omp::AtomicCaptureOp>([&](
auto op) { checkHint(op,
result); })
466 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
467 [&](
auto op) { checkDepend(op,
result); })
468 .Case([&](omp::TargetUpdateOp op) { checkDepend(op,
result); })
469 .Case([&](omp::TargetOp op) {
470 checkAllocate(op,
result);
472 checkDependIteratorModifier(op,
result);
473 checkInReduction(op,
result);
474 checkThreadLimit(op,
result);
486 llvm::handleAllErrors(
488 [&](
const PreviouslyReportedError &) {
result = failure(); },
489 [&](
const llvm::ErrorInfoBase &err) {
506static llvm::OpenMPIRBuilder::InsertPointTy
512 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
514 [&](OpenMPAllocaStackFrame &frame) {
515 allocaInsertPoint = frame.allocaInsertPoint;
523 allocaInsertPoint.getBlock()->getParent() ==
524 builder.GetInsertBlock()->getParent())
525 return allocaInsertPoint;
534 if (builder.GetInsertBlock() ==
535 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
536 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
537 "Assuming end of basic block");
538 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
539 builder.getContext(),
"entry", builder.GetInsertBlock()->getParent(),
540 builder.GetInsertBlock()->getNextNode());
541 builder.CreateBr(entryBB);
542 builder.SetInsertPoint(entryBB);
545 llvm::BasicBlock &funcEntryBlock =
546 builder.GetInsertBlock()->getParent()->getEntryBlock();
547 return llvm::OpenMPIRBuilder::InsertPointTy(
548 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
554static llvm::CanonicalLoopInfo *
556 llvm::CanonicalLoopInfo *loopInfo =
nullptr;
557 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
558 [&](OpenMPLoopInfoStackFrame &frame) {
559 loopInfo = frame.loopInfo;
571 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
574 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.
getParentOp());
576 llvm::BasicBlock *continuationBlock =
577 splitBB(builder,
true,
"omp.region.cont");
578 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
580 llvm::LLVMContext &llvmContext = builder.getContext();
581 for (
Block &bb : region) {
582 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
583 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
584 builder.GetInsertBlock()->getNextNode());
585 moduleTranslation.
mapBlock(&bb, llvmBB);
588 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
595 unsigned numYields = 0;
597 if (!isLoopWrapper) {
598 bool operandsProcessed =
false;
600 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
601 if (!operandsProcessed) {
602 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
603 continuationBlockPHITypes.push_back(
604 moduleTranslation.
convertType(yield->getOperand(i).getType()));
606 operandsProcessed =
true;
608 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
609 "mismatching number of values yielded from the region");
610 for (
unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
611 llvm::Type *operandType =
612 moduleTranslation.
convertType(yield->getOperand(i).getType());
614 assert(continuationBlockPHITypes[i] == operandType &&
615 "values of mismatching types yielded from the region");
625 if (!continuationBlockPHITypes.empty())
627 continuationBlockPHIs &&
628 "expected continuation block PHIs if converted regions yield values");
629 if (continuationBlockPHIs) {
630 llvm::IRBuilderBase::InsertPointGuard guard(builder);
631 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
632 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
633 for (llvm::Type *ty : continuationBlockPHITypes)
634 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
640 for (
Block *bb : blocks) {
641 llvm::BasicBlock *llvmBB = moduleTranslation.
lookupBlock(bb);
644 if (bb->isEntryBlock()) {
645 assert(sourceTerminator->getNumSuccessors() == 1 &&
646 "provided entry block has multiple successors");
647 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
648 "ContinuationBlock is not the successor of the entry block");
649 sourceTerminator->setSuccessor(0, llvmBB);
652 llvm::IRBuilderBase::InsertPointGuard guard(builder);
654 moduleTranslation.
convertBlock(*bb, bb->isEntryBlock(), builder)))
655 return llvm::make_error<PreviouslyReportedError>();
660 builder.CreateBr(continuationBlock);
671 Operation *terminator = bb->getTerminator();
672 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
673 builder.CreateBr(continuationBlock);
675 for (
unsigned i = 0, e = terminator->
getNumOperands(); i < e; ++i)
676 (*continuationBlockPHIs)[i]->addIncoming(
690 return continuationBlock;
696 case omp::ClauseProcBindKind::Close:
697 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
698 case omp::ClauseProcBindKind::Master:
699 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
700 case omp::ClauseProcBindKind::Primary:
701 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
702 case omp::ClauseProcBindKind::Spread:
703 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
705 llvm_unreachable(
"Unknown ClauseProcBindKind kind");
712 auto maskedOp = cast<omp::MaskedOp>(opInst);
713 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
718 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
720 auto ®ion = maskedOp.getRegion();
721 builder.restoreIP(codeGenIP);
729 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
731 llvm::Value *filterVal =
nullptr;
732 if (
auto filterVar = maskedOp.getFilteredThreadId()) {
733 filterVal = moduleTranslation.
lookupValue(filterVar);
735 llvm::LLVMContext &llvmContext = builder.getContext();
737 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
739 assert(filterVal !=
nullptr);
740 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
741 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
748 builder.restoreIP(*afterIP);
756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
757 auto masterOp = cast<omp::MasterOp>(opInst);
762 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
764 auto ®ion = masterOp.getRegion();
765 builder.restoreIP(codeGenIP);
773 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
775 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
776 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
783 builder.restoreIP(*afterIP);
791 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
792 auto criticalOp = cast<omp::CriticalOp>(opInst);
797 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
799 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion();
800 builder.restoreIP(codeGenIP);
808 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
810 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
811 llvm::LLVMContext &llvmContext = moduleTranslation.
getLLVMContext();
812 llvm::Constant *hint =
nullptr;
815 if (criticalOp.getNameAttr()) {
818 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
819 auto criticalDeclareOp =
823 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
824 static_cast<int>(criticalDeclareOp.getHint()));
826 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
828 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(
""), hint);
833 builder.restoreIP(*afterIP);
840 template <
typename OP>
843 cast<
omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
846 collectPrivatizationDecls<OP>(op);
861 void collectPrivatizationDecls(OP op) {
862 std::optional<ArrayAttr> attr = op.getPrivateSyms();
867 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
878 std::optional<ArrayAttr> attr = op.getReductionSyms();
882 reductions.reserve(reductions.size() + op.getNumReductionVars());
883 for (
auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
884 reductions.push_back(
896 Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
905 llvm::Instruction *potentialTerminator =
906 builder.GetInsertBlock()->empty() ?
nullptr
907 : &builder.GetInsertBlock()->back();
909 if (potentialTerminator && potentialTerminator->isTerminator())
910 potentialTerminator->removeFromParent();
911 moduleTranslation.
mapBlock(®ion.
front(), builder.GetInsertBlock());
914 region.
front(),
true, builder)))
918 if (continuationBlockArgs)
920 *continuationBlockArgs,
927 if (potentialTerminator && potentialTerminator->isTerminator()) {
928 llvm::BasicBlock *block = builder.GetInsertBlock();
929 if (block->empty()) {
935 potentialTerminator->insertInto(block, block->begin());
937 potentialTerminator->insertAfter(&block->back());
951 if (continuationBlockArgs)
952 llvm::append_range(*continuationBlockArgs, phis);
953 builder.SetInsertPoint(*continuationBlock,
954 (*continuationBlock)->getFirstInsertionPt());
// Owning std::function wrappers around the OpenMPIRBuilder reduction
// callback types, so the closures (which capture local state) outlive the
// scope that builds them.
// NOTE(review): the final parameter/closing tokens of the first two alias
// definitions fall on lines missing from this chunk — confirm against the
// full file.
// Non-atomic reduction combiner: (insert point, lhs, rhs, result out-param).
961using OwningReductionGen =
962 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
963 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
// Atomic reduction combiner: also receives the element llvm::Type.
965using OwningAtomicReductionGen =
966 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
967 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
// Generator producing the data pointer for by-ref reductions.
969using OwningDataPtrPtrReductionGen =
970 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
971 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
977static OwningReductionGen
983 OwningReductionGen gen =
984 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
985 llvm::Value *
lhs, llvm::Value *
rhs,
986 llvm::Value *&
result)
mutable
987 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
988 moduleTranslation.
mapValue(decl.getReductionLhsArg(),
lhs);
989 moduleTranslation.
mapValue(decl.getReductionRhsArg(),
rhs);
990 builder.restoreIP(insertPoint);
993 "omp.reduction.nonatomic.body", builder,
994 moduleTranslation, &phis)))
995 return llvm::createStringError(
996 "failed to inline `combiner` region of `omp.declare_reduction`");
997 result = llvm::getSingleElement(phis);
998 return builder.saveIP();
1007static OwningAtomicReductionGen
1009 llvm::IRBuilderBase &builder,
1011 if (decl.getAtomicReductionRegion().empty())
1012 return OwningAtomicReductionGen();
1018 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1019 llvm::Value *
lhs, llvm::Value *
rhs)
mutable
1020 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1021 moduleTranslation.
mapValue(decl.getAtomicReductionLhsArg(),
lhs);
1022 moduleTranslation.
mapValue(decl.getAtomicReductionRhsArg(),
rhs);
1023 builder.restoreIP(insertPoint);
1026 "omp.reduction.atomic.body", builder,
1027 moduleTranslation, &phis)))
1028 return llvm::createStringError(
1029 "failed to inline `atomic` region of `omp.declare_reduction`");
1030 assert(phis.empty());
1031 return builder.saveIP();
1040static OwningDataPtrPtrReductionGen
1044 return OwningDataPtrPtrReductionGen();
1046 OwningDataPtrPtrReductionGen refDataPtrGen =
1047 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1048 llvm::Value *byRefVal, llvm::Value *&
result)
mutable
1049 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1050 moduleTranslation.
mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1051 builder.restoreIP(insertPoint);
1054 "omp.data_ptr_ptr.body", builder,
1055 moduleTranslation, &phis)))
1056 return llvm::createStringError(
1057 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1058 result = llvm::getSingleElement(phis);
1059 return builder.saveIP();
1062 return refDataPtrGen;
1069 auto orderedOp = cast<omp::OrderedOp>(opInst);
1074 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1075 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1076 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1078 moduleTranslation.
lookupValues(orderedOp.getDoacrossDependVars());
1080 size_t indexVecValues = 0;
1081 while (indexVecValues < vecValues.size()) {
1083 storeValues.reserve(numLoops);
1084 for (
unsigned i = 0; i < numLoops; i++) {
1085 storeValues.push_back(vecValues[indexVecValues]);
1088 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1090 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1091 builder.restoreIP(moduleTranslation.
getOpenMPBuilder()->createOrderedDepend(
1092 ompLoc, allocaIP, numLoops, storeValues,
".cnt.addr", isDependSource));
1102 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1103 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1108 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1110 auto ®ion = cast<omp::OrderedRegionOp>(opInst).getRegion();
1111 builder.restoreIP(codeGenIP);
1119 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1121 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1122 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1124 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1129 builder.restoreIP(*afterIP);
// A store of `value` to `address` whose emission is deferred until after
// reduction allocations are complete (see the deferredStores vectors used
// by the reduction-allocation helpers in this file).
1135struct DeferredStore {
1136 DeferredStore(llvm::Value *value, llvm::Value *address)
1137 : value(value), address(address) {}
// NOTE(review): the declaration of the `value` member falls on a line
// missing from this chunk — confirm against the full file.
1140 llvm::Value *address;
1147template <
typename T>
1150 llvm::IRBuilderBase &builder,
1152 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1158 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1159 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1162 deferredStores.reserve(loop.getNumReductionVars());
1164 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1165 Region &allocRegion = reductionDecls[i].getAllocRegion();
1167 if (allocRegion.
empty())
1172 builder, moduleTranslation, &phis)))
1173 return loop.emitError(
1174 "failed to inline `alloc` region of `omp.declare_reduction`");
1176 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1177 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1181 llvm::Value *var = builder.CreateAlloca(
1182 moduleTranslation.
convertType(reductionDecls[i].getType()));
1184 llvm::Type *ptrTy = builder.getPtrTy();
1185 llvm::Value *castVar =
1186 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1187 llvm::Value *castPhi =
1188 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1190 deferredStores.emplace_back(castPhi, castVar);
1192 privateReductionVariables[i] = castVar;
1193 moduleTranslation.
mapValue(reductionArgs[i], castPhi);
1194 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1196 assert(allocRegion.
empty() &&
1197 "allocaction is implicit for by-val reduction");
1198 llvm::Value *var = builder.CreateAlloca(
1199 moduleTranslation.
convertType(reductionDecls[i].getType()));
1201 llvm::Type *ptrTy = builder.getPtrTy();
1202 llvm::Value *castVar =
1203 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1205 moduleTranslation.
mapValue(reductionArgs[i], castVar);
1206 privateReductionVariables[i] = castVar;
1207 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1215template <
typename T>
1218 llvm::IRBuilderBase &builder,
1223 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1224 Region &initializerRegion = reduction.getInitializerRegion();
1227 mlir::Value mlirSource = loop.getReductionVars()[i];
1228 llvm::Value *llvmSource = moduleTranslation.
lookupValue(mlirSource);
1229 llvm::Value *origVal = llvmSource;
1231 if (!isa<LLVM::LLVMPointerType>(
1232 reduction.getInitializerMoldArg().getType()) &&
1233 isa<LLVM::LLVMPointerType>(mlirSource.
getType())) {
1236 reduction.getInitializerMoldArg().getType()),
1237 llvmSource,
"omp_orig");
1239 moduleTranslation.
mapValue(reduction.getInitializerMoldArg(), origVal);
1242 llvm::Value *allocation =
1243 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1244 moduleTranslation.
mapValue(reduction.getInitializerAllocArg(), allocation);
1250 llvm::BasicBlock *block =
nullptr) {
1251 if (block ==
nullptr)
1252 block = builder.GetInsertBlock();
1254 if (!block->hasTerminator())
1255 builder.SetInsertPoint(block);
1257 builder.SetInsertPoint(block->getTerminator());
1265template <
typename OP>
1268 llvm::IRBuilderBase &builder,
1270 llvm::BasicBlock *latestAllocaBlock,
1276 if (op.getNumReductionVars() == 0)
1279 llvm::BasicBlock *initBlock = splitBB(builder,
true,
"omp.reduction.init");
1280 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1281 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1282 builder.restoreIP(allocaIP);
1285 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1287 if (!reductionDecls[i].getAllocRegion().empty())
1293 byRefVars[i] = builder.CreateAlloca(
1294 moduleTranslation.
convertType(reductionDecls[i].getType()));
1302 for (
auto [data, addr] : deferredStores)
1303 builder.CreateStore(data, addr);
1308 for (
unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1313 reductionVariableMap, i);
1321 "omp.reduction.neutral", builder,
1322 moduleTranslation, &phis)))
1325 assert(phis.size() == 1 &&
"expected one value to be yielded from the "
1326 "reduction neutral element declaration region");
1331 if (!reductionDecls[i].getAllocRegion().empty())
1340 builder.CreateStore(phis[0], byRefVars[i]);
1342 privateReductionVariables[i] = byRefVars[i];
1343 moduleTranslation.
mapValue(reductionArgs[i], phis[0]);
1344 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1347 builder.CreateStore(phis[0], privateReductionVariables[i]);
1354 moduleTranslation.
forgetMapping(reductionDecls[i].getInitializerRegion());
1361template <
typename T>
1362static void collectReductionInfo(
1363 T loop, llvm::IRBuilderBase &builder,
1372 unsigned numReductions = loop.getNumReductionVars();
1374 for (
unsigned i = 0; i < numReductions; ++i) {
1377 owningAtomicReductionGens.push_back(
1380 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1384 reductionInfos.reserve(numReductions);
1385 for (
unsigned i = 0; i < numReductions; ++i) {
1386 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy
atomicGen =
nullptr;
1387 if (owningAtomicReductionGens[i])
1388 atomicGen = owningAtomicReductionGens[i];
1389 llvm::Value *variable =
1390 moduleTranslation.
lookupValue(loop.getReductionVars()[i]);
1393 if (
auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1394 allocatedType = alloca.getElemType();
1401 reductionInfos.push_back(
1403 privateReductionVariables[i],
1404 llvm::OpenMPIRBuilder::EvalKind::Scalar,
1408 allocatedType ? moduleTranslation.
convertType(allocatedType) :
nullptr,
1409 reductionDecls[i].getByrefElementType()
1411 *reductionDecls[i].getByrefElementType())
1421 llvm::IRBuilderBase &builder, StringRef regionName,
1422 bool shouldLoadCleanupRegionArg =
true) {
1423 for (
auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1424 if (cleanupRegion->empty())
1430 llvm::Instruction *potentialTerminator =
1431 builder.GetInsertBlock()->empty() ?
nullptr
1432 : &builder.GetInsertBlock()->back();
1433 if (potentialTerminator && potentialTerminator->isTerminator())
1434 builder.SetInsertPoint(potentialTerminator);
1435 llvm::Value *privateVarValue =
1436 shouldLoadCleanupRegionArg
1437 ? builder.CreateLoad(
1439 privateVariables[i])
1440 : privateVariables[i];
1445 moduleTranslation)))
1458 OP op, llvm::IRBuilderBase &builder,
1460 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1463 bool isNowait =
false,
bool isTeamsReduction =
false) {
1465 if (op.getNumReductionVars() == 0)
1477 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1479 owningReductionGenRefDataPtrGens,
1480 privateReductionVariables, reductionInfos, isByRef);
1485 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1486 builder.SetInsertPoint(tempTerminator);
1487 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1488 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1489 isByRef, isNowait, isTeamsReduction);
1494 if (!contInsertPoint->getBlock())
1495 return op->emitOpError() <<
"failed to convert reductions";
1497 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1498 if (!isTeamsReduction) {
1499 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1500 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1504 afterIP = *barrierIP;
1507 tempTerminator->eraseFromParent();
1508 builder.restoreIP(afterIP);
1512 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1513 [](omp::DeclareReductionOp reductionDecl) {
1514 return &reductionDecl.getCleanupRegion();
1517 moduleTranslation, builder,
1518 "omp.reduction.cleanup");
1529template <
typename OP>
1533 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1538 if (op.getNumReductionVars() == 0)
1544 allocaIP, reductionDecls,
1545 privateReductionVariables, reductionVariableMap,
1546 deferredStores, isByRef)))
1549 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1550 allocaIP.getBlock(), reductionDecls,
1551 privateReductionVariables, reductionVariableMap,
1552 isByRef, deferredStores);
1566 if (mappedPrivateVars ==
nullptr || !mappedPrivateVars->contains(privateVar))
1569 Value blockArg = (*mappedPrivateVars)[privateVar];
1572 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1573 "A block argument corresponding to a mapped var should have "
1576 if (privVarType == blockArgType)
1583 if (!isa<LLVM::LLVMPointerType>(privVarType))
1584 return builder.CreateLoad(moduleTranslation.
convertType(privVarType),
1597 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1599 llvm::BasicBlock *privInitBlock,
1601 Region &initRegion = privDecl.getInitRegion();
1602 if (initRegion.
empty())
1603 return llvmPrivateVar;
1605 assert(nonPrivateVar);
1606 moduleTranslation.
mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1607 moduleTranslation.
mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1612 moduleTranslation, &phis)))
1613 return llvm::createStringError(
1614 "failed to inline `init` region of `omp.private`");
1616 assert(phis.size() == 1 &&
"expected one allocation to be yielded");
1633 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1636 builder, moduleTranslation, privDecl,
1639 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1648 return llvm::Error::success();
1650 llvm::BasicBlock *privInitBlock = splitBB(builder,
true,
"omp.private.init");
1653 for (
auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1656 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1658 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1659 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1662 return privVarOrErr.takeError();
1664 llvmPrivateVar = privVarOrErr.get();
1665 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
1670 return llvm::Error::success();
1680 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1683 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1684 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1685 allocaTerminator->getIterator()),
1686 true, allocaTerminator->getStableDebugLoc(),
1687 "omp.region.after_alloca");
1689 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1691 allocaTerminator = allocaIP.getBlock()->getTerminator();
1692 builder.SetInsertPoint(allocaTerminator);
1694 assert(allocaTerminator->getNumSuccessors() == 1 &&
1695 "This is an unconditional branch created by splitBB");
1697 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1698 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1700 unsigned int allocaAS =
1701 moduleTranslation.
getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1704 .getProgramAddressSpace();
1706 for (
auto [privDecl, mlirPrivVar, blockArg] :
1709 llvm::Type *llvmAllocType =
1710 moduleTranslation.
convertType(privDecl.getType());
1711 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1712 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1713 llvmAllocType,
nullptr,
"omp.private.alloc");
1714 if (allocaAS != defaultAS)
1715 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1716 builder.getPtrTy(defaultAS));
1718 privateVarsInfo.
llvmVars.push_back(llvmPrivateVar);
1721 return afterAllocas;
1729 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1738 if (mlir::isa<omp::ParallelOp>(parent))
1752 bool needsFirstprivate =
1753 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1754 return privOp.getDataSharingType() ==
1755 omp::DataSharingClauseType::FirstPrivate;
1758 if (!needsFirstprivate)
1761 llvm::BasicBlock *copyBlock =
1762 splitBB(builder,
true,
"omp.private.copy");
1765 for (
auto [decl, moldVar, llvmVar] :
1766 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1767 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1771 Region ©Region = decl.getCopyRegion();
1773 moduleTranslation.
mapValue(decl.getCopyMoldArg(), moldVar);
1776 moduleTranslation.
mapValue(decl.getCopyPrivateArg(), llvmVar);
1780 moduleTranslation)))
1781 return decl.emitError(
"failed to inline `copy` region of `omp.private`");
1795 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1796 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1812 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](
mlir::Value mlirVar) {
1814 llvm::Value *moldVar = findAssociatedValue(
1815 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1820 llvmPrivateVars, privateDecls, insertBarrier,
1831 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1832 [](omp::PrivateClauseOp privatizer) {
1833 return &privatizer.getDeallocRegion();
1837 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1838 "omp.private.dealloc",
false)))
1839 return mlir::emitError(loc,
"failed to inline `dealloc` region of an "
1840 "`omp.private` op in");
1852 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1862 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1863 using StorableBodyGenCallbackTy =
1864 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1866 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1872 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1876 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1880 sectionsOp.getNumReductionVars());
1884 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1887 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1888 reductionDecls, privateReductionVariables, reductionVariableMap,
1895 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1899 Region ®ion = sectionOp.getRegion();
1900 auto sectionCB = [§ionsOp, ®ion, &builder, &moduleTranslation](
1901 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1902 builder.restoreIP(codeGenIP);
1909 sectionsOp.getRegion().getNumArguments());
1910 for (
auto [sectionsArg, sectionArg] : llvm::zip_equal(
1911 sectionsOp.getRegion().getArguments(), region.
getArguments())) {
1912 llvm::Value *llvmVal = moduleTranslation.
lookupValue(sectionsArg);
1914 moduleTranslation.
mapValue(sectionArg, llvmVal);
1921 sectionCBs.push_back(sectionCB);
1927 if (sectionCBs.empty())
1930 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1935 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1936 llvm::Value &vPtr, llvm::Value *&replacementValue)
1937 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1938 replacementValue = &vPtr;
1944 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1948 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1949 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1951 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1952 sectionsOp.getNowait());
1957 builder.restoreIP(*afterIP);
1961 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1962 privateReductionVariables, isByRef, sectionsOp.getNowait());
1969 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1970 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1975 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1976 builder.restoreIP(codegenIP);
1978 builder, moduleTranslation)
1981 auto finiCB = [&](InsertPointTy codeGenIP) {
return llvm::Error::success(); };
1985 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1988 for (
size_t i = 0, e = cpVars.size(); i < e; ++i) {
1989 llvmCPVars.push_back(moduleTranslation.
lookupValue(cpVars[i]));
1991 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1992 llvmCPFuncs.push_back(
1996 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1998 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2004 builder.restoreIP(*afterIP);
2010 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2015 for (
auto ra : iface.getReductionBlockArgs())
2016 for (
auto &use : ra.getUses()) {
2017 auto *useOp = use.getOwner();
2019 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2020 debugUses.push_back(useOp);
2024 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2029 Operation *currentOp = currentDistOp.getOperation();
2030 if (distOp && (distOp != currentOp))
2039 for (
auto *use : debugUses)
2048 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2053 unsigned numReductionVars = op.getNumReductionVars();
2057 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2063 if (doTeamsReduction) {
2064 isByRef =
getIsByRef(op.getReductionByref());
2066 assert(isByRef.size() == op.getNumReductionVars());
2069 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2074 op, reductionArgs, builder, moduleTranslation, allocaIP,
2075 reductionDecls, privateReductionVariables, reductionVariableMap,
2080 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2082 moduleTranslation, allocaIP);
2083 builder.restoreIP(codegenIP);
2089 llvm::Value *numTeamsLower =
nullptr;
2090 if (
Value numTeamsLowerVar = op.getNumTeamsLower())
2091 numTeamsLower = moduleTranslation.
lookupValue(numTeamsLowerVar);
2093 llvm::Value *numTeamsUpper =
nullptr;
2094 if (!op.getNumTeamsUpperVars().empty())
2095 numTeamsUpper = moduleTranslation.
lookupValue(op.getNumTeams(0));
2097 llvm::Value *threadLimit =
nullptr;
2098 if (!op.getThreadLimitVars().empty())
2099 threadLimit = moduleTranslation.
lookupValue(op.getThreadLimit(0));
2101 llvm::Value *ifExpr =
nullptr;
2102 if (
Value ifVar = op.getIfExpr())
2105 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2106 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2108 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2113 builder.restoreIP(*afterIP);
2114 if (doTeamsReduction) {
2117 op, builder, moduleTranslation, allocaIP, reductionDecls,
2118 privateReductionVariables, isByRef,
2128 if (dependVars.empty())
2130 for (
auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2131 llvm::omp::RTLDependenceKindTy type;
2133 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2134 case mlir::omp::ClauseTaskDepend::taskdependin:
2135 type = llvm::omp::RTLDependenceKindTy::DepIn;
2140 case mlir::omp::ClauseTaskDepend::taskdependout:
2141 case mlir::omp::ClauseTaskDepend::taskdependinout:
2142 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2144 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2145 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2147 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2148 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2151 llvm::Value *depVal = moduleTranslation.
lookupValue(std::get<0>(dep));
2152 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2153 dds.emplace_back(dd);
2165 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2167 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2168 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2172 llvmBuilder.restoreIP(ip);
2178 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2179 return llvm::Error::success();
2184 ompBuilder.pushFinalizationCB(
2194 llvm::OpenMPIRBuilder &ompBuilder,
2195 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2196 ompBuilder.popFinalizationCB();
2197 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2198 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2199 cancelBranch->setSuccessor(constructFini);
2205class TaskContextStructManager {
2207 TaskContextStructManager(llvm::IRBuilderBase &builder,
2208 LLVM::ModuleTranslation &moduleTranslation,
2209 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2210 : builder{builder}, moduleTranslation{moduleTranslation},
2211 privateDecls{privateDecls} {}
2217 void generateTaskContextStruct();
2223 void createGEPsToPrivateVars();
2229 SmallVector<llvm::Value *>
2230 createGEPsToPrivateVars(llvm::Value *altStructPtr)
const;
2233 void freeStructPtr();
2235 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2236 return llvmPrivateVarGEPs;
2239 llvm::Value *getStructPtr() {
return structPtr; }
2242 llvm::IRBuilderBase &builder;
2243 LLVM::ModuleTranslation &moduleTranslation;
2244 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2247 SmallVector<llvm::Type *> privateVarTypes;
2251 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2254 llvm::Value *structPtr =
nullptr;
2256 llvm::Type *structTy =
nullptr;
2267 llvm::SmallVector<llvm::Value *> lowerBounds;
2268 llvm::SmallVector<llvm::Value *> upperBounds;
2269 llvm::SmallVector<llvm::Value *> steps;
2270 llvm::SmallVector<llvm::Value *> trips;
2272 llvm::Value *totalTrips;
2274 llvm::Value *lookUpAsI64(mlir::Value val,
const LLVM::ModuleTranslation &mt,
2275 llvm::IRBuilderBase &builder) {
2279 if (v->getType()->isIntegerTy(64))
2281 if (v->getType()->isIntegerTy())
2282 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2287 IteratorInfo(mlir::omp::IteratorOp itersOp,
2288 mlir::LLVM::ModuleTranslation &moduleTranslation,
2289 llvm::IRBuilderBase &builder) {
2290 dims = itersOp.getLoopLowerBounds().size();
2291 lowerBounds.resize(dims);
2292 upperBounds.resize(dims);
2296 for (
unsigned d = 0; d < dims; ++d) {
2297 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2298 moduleTranslation, builder);
2299 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2300 moduleTranslation, builder);
2302 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2303 assert(lb && ub && st &&
2304 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2305 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2306 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2307 "Expect non-zero step in IteratorOp");
2309 lowerBounds[d] = lb;
2310 upperBounds[d] = ub;
2314 llvm::Value *diff = builder.CreateSub(ub, lb);
2315 llvm::Value *
div = builder.CreateSDiv(diff, st);
2316 trips[d] = builder.CreateAdd(
2317 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2320 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2321 for (
unsigned d = 0; d < dims; ++d)
2322 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2325 unsigned getDims()
const {
return dims; }
2326 llvm::ArrayRef<llvm::Value *> getLowerBounds()
const {
return lowerBounds; }
2327 llvm::ArrayRef<llvm::Value *> getUpperBounds()
const {
return upperBounds; }
2328 llvm::ArrayRef<llvm::Value *> getSteps()
const {
return steps; }
2329 llvm::ArrayRef<llvm::Value *> getTrips()
const {
return trips; }
2330 llvm::Value *getTotalTrips()
const {
return totalTrips; }
2335void TaskContextStructManager::generateTaskContextStruct() {
2336 if (privateDecls.empty())
2338 privateVarTypes.reserve(privateDecls.size());
2340 for (omp::PrivateClauseOp &privOp : privateDecls) {
2343 if (!privOp.readsFromMold())
2345 Type mlirType = privOp.getType();
2346 privateVarTypes.push_back(moduleTranslation.
convertType(mlirType));
2349 if (privateVarTypes.empty())
2352 structTy = llvm::StructType::get(moduleTranslation.
getLLVMContext(),
2355 llvm::DataLayout dataLayout =
2356 builder.GetInsertBlock()->getModule()->getDataLayout();
2357 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2358 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2361 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2363 "omp.task.context_ptr");
2366SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2367 llvm::Value *altStructPtr)
const {
2368 SmallVector<llvm::Value *> ret;
2371 ret.reserve(privateDecls.size());
2372 llvm::Value *zero = builder.getInt32(0);
2374 for (
auto privDecl : privateDecls) {
2375 if (!privDecl.readsFromMold()) {
2377 ret.push_back(
nullptr);
2380 llvm::Value *iVal = builder.getInt32(i);
2381 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2388void TaskContextStructManager::createGEPsToPrivateVars() {
2390 assert(privateVarTypes.empty());
2394 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2397void TaskContextStructManager::freeStructPtr() {
2401 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2403 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2404 builder.CreateFree(structPtr);
2408 llvm::OpenMPIRBuilder &ompBuilder,
2409 llvm::Value *affinityList, llvm::Value *
index,
2410 llvm::Value *addr, llvm::Value *len) {
2411 llvm::StructType *kmpTaskAffinityInfoTy =
2412 ompBuilder.getKmpTaskAffinityInfoTy();
2413 llvm::Value *entry = builder.CreateInBoundsGEP(
2414 kmpTaskAffinityInfoTy, affinityList,
index,
"omp.affinity.entry");
2416 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2417 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2419 llvm::Value *flags = builder.getInt32(0);
2421 builder.CreateStore(addr,
2422 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2423 builder.CreateStore(len,
2424 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2425 builder.CreateStore(flags,
2426 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2430 llvm::IRBuilderBase &builder,
2432 llvm::Value *affinityList) {
2433 for (
auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2434 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2435 assert(entryOp &&
"affinity item must be omp.affinity_entry");
2437 llvm::Value *addr = moduleTranslation.
lookupValue(entryOp.getAddr());
2438 llvm::Value *len = moduleTranslation.
lookupValue(entryOp.getLen());
2439 assert(addr && len &&
"expect affinity addr and len to be non-null");
2441 affinityList, builder.getInt64(i), addr, len);
2445static mlir::LogicalResult
2448 llvm::IRBuilderBase &builder,
2450 llvm::Value *tmp = linearIV;
2451 for (
int d = (
int)iterInfo.getDims() - 1; d >= 0; --d) {
2452 llvm::Value *trip = iterInfo.getTrips()[d];
2454 llvm::Value *idx = builder.CreateURem(tmp, trip);
2456 tmp = builder.CreateUDiv(tmp, trip);
2459 llvm::Value *physIV = builder.CreateAdd(
2460 iterInfo.getLowerBounds()[d],
2461 builder.CreateMul(idx, iterInfo.getSteps()[d]),
"omp.it.phys_iv");
2467 moduleTranslation.
mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2468 if (mlir::failed(moduleTranslation.
convertBlock(iteratorRegionBlock,
2471 return mlir::failure();
2473 return mlir::success();
2479static mlir::LogicalResult
2482 IteratorInfo &iterInfo, llvm::StringRef loopName,
2487 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2489 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2490 llvm::Value *linearIV) -> llvm::Error {
2491 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2492 builder.restoreIP(bodyIP);
2495 builder, moduleTranslation))) {
2496 return llvm::make_error<llvm::StringError>(
2497 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2501 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.
getTerminator());
2502 assert(yield && yield.getResults().size() == 1 &&
2503 "expect omp.yield in iterator region to have one result");
2505 genStoreEntry(linearIV, yield);
2511 return llvm::Error::success();
2514 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2516 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2520 builder.restoreIP(*afterIP);
2522 return mlir::success();
2525static mlir::LogicalResult
2528 llvm::OpenMPIRBuilder::AffinityData &ad) {
2530 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2533 return mlir::success();
2537 llvm::StructType *kmpTaskAffinityInfoTy =
2540 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2541 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2542 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2544 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2545 "omp.affinity_list");
2548 auto createAffinity =
2549 [&](llvm::Value *count,
2550 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2551 llvm::OpenMPIRBuilder::AffinityData ad{};
2552 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2554 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2558 if (!taskOp.getAffinityVars().empty()) {
2559 llvm::Value *count = llvm::ConstantInt::get(
2560 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2561 llvm::Value *list = allocateAffinityList(count);
2564 ads.emplace_back(createAffinity(count, list));
2567 if (!taskOp.getIterated().empty()) {
2568 for (
auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2569 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2570 assert(itersOp &&
"iterated value must be defined by omp.iterator");
2571 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2572 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2574 itersOp, builder, moduleTranslation, iterInfo,
"iterator",
2575 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2576 auto entryOp = yield.getResults()[0]
2577 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2578 assert(entryOp &&
"expect yield produce an affinity entry");
2585 affList, linearIV, addr, len);
2587 return llvm::failure();
2588 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2592 llvm::Value *totalAffinityCount = builder.getInt32(0);
2593 for (
const auto &affinity : ads)
2594 totalAffinityCount = builder.CreateAdd(
2596 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2599 llvm::Value *affinityInfo = ads.front().Info;
2600 if (ads.size() > 1) {
2601 llvm::StructType *kmpTaskAffinityInfoTy =
2603 llvm::Value *affinityInfoElemSize = builder.getInt64(
2604 moduleTranslation.
getLLVMModule()->getDataLayout().getTypeAllocSize(
2605 kmpTaskAffinityInfoTy));
2607 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2608 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2609 for (
const auto &affinity : ads) {
2610 llvm::Value *affinityCount = builder.CreateIntCast(
2611 affinity.Count, builder.getInt32Ty(),
false);
2612 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2613 affinityCount, builder.getInt64Ty(),
false);
2614 llvm::Value *affinityInfoSize =
2615 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2617 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2618 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2620 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2621 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2623 builder.CreateMemCpy(
2624 packedAffinityInfoIndex, llvm::Align(1),
2625 builder.CreatePointerBitCastOrAddrSpaceCast(
2626 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2627 ->getPointerAddressSpace())),
2628 llvm::Align(1), affinityInfoSize);
2630 packedAffinityInfoOffset =
2631 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2634 affinityInfo = packedAffinityInfo;
2637 ad.Count = totalAffinityCount;
2638 ad.Info = affinityInfo;
2640 return mlir::success();
2647 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2652 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2664 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2669 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2670 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2671 builder.getContext(),
"omp.task.start",
2672 builder.GetInsertBlock()->getParent());
2673 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2674 builder.SetInsertPoint(branchToTaskStartBlock);
2677 llvm::BasicBlock *copyBlock =
2678 splitBB(builder,
true,
"omp.private.copy");
2679 llvm::BasicBlock *initBlock =
2680 splitBB(builder,
true,
"omp.private.init");
2696 moduleTranslation, allocaIP);
2699 builder.SetInsertPoint(initBlock->getTerminator());
2702 taskStructMgr.generateTaskContextStruct();
2709 taskStructMgr.createGEPsToPrivateVars();
2711 for (
auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2714 taskStructMgr.getLLVMPrivateVarGEPs())) {
2716 if (!privDecl.readsFromMold())
2718 assert(llvmPrivateVarAlloc &&
2719 "reads from mold so shouldn't have been skipped");
2722 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2723 blockArg, llvmPrivateVarAlloc, initBlock);
2724 if (!privateVarOrErr)
2725 return handleError(privateVarOrErr, *taskOp.getOperation());
2734 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2735 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2736 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2738 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2739 llvmPrivateVarAlloc);
2741 assert(llvmPrivateVarAlloc->getType() ==
2742 moduleTranslation.
convertType(blockArg.getType()));
2752 taskOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2753 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2754 taskOp.getPrivateNeedsBarrier())))
2755 return llvm::failure();
2757 llvm::OpenMPIRBuilder::AffinityData ad;
2759 return llvm::failure();
2762 builder.SetInsertPoint(taskStartBlock);
2764 auto bodyCB = [&](InsertPointTy allocaIP,
2765 InsertPointTy codegenIP) -> llvm::Error {
2769 moduleTranslation, allocaIP);
2772 builder.restoreIP(codegenIP);
2774 llvm::BasicBlock *privInitBlock =
nullptr;
2776 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2779 auto [blockArg, privDecl, mlirPrivVar] = zip;
2781 if (privDecl.readsFromMold())
2784 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2785 llvm::Type *llvmAllocType =
2786 moduleTranslation.
convertType(privDecl.getType());
2787 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2788 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2789 llvmAllocType,
nullptr,
"omp.private.alloc");
2792 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2793 blockArg, llvmPrivateVar, privInitBlock);
2794 if (!privateVarOrError)
2795 return privateVarOrError.takeError();
2796 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2797 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
2800 taskStructMgr.createGEPsToPrivateVars();
2801 for (
auto [i, llvmPrivVar] :
2802 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2804 assert(privateVarsInfo.
llvmVars[i] &&
2805 "This is added in the loop above");
2808 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
2813 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
2817 if (!privateDecl.readsFromMold())
2820 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2821 llvmPrivateVar = builder.CreateLoad(
2822 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
2824 assert(llvmPrivateVar->getType() ==
2825 moduleTranslation.
convertType(blockArg.getType()));
2826 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
2830 taskOp.getRegion(),
"omp.task.region", builder, moduleTranslation);
2831 if (failed(
handleError(continuationBlockOrError, *taskOp)))
2832 return llvm::make_error<PreviouslyReportedError>();
2834 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2839 return llvm::make_error<PreviouslyReportedError>();
2842 taskStructMgr.freeStructPtr();
2844 return llvm::Error::success();
2853 llvm::omp::Directive::OMPD_taskgroup);
2857 moduleTranslation, dds);
2859 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2860 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2862 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2864 moduleTranslation.
lookupValue(taskOp.getIfExpr()), dds, ad,
2865 taskOp.getMergeable(),
2866 moduleTranslation.
lookupValue(taskOp.getEventHandle()),
2867 moduleTranslation.
lookupValue(taskOp.getPriority()));
2875 builder.restoreIP(*afterIP);
2883 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2884 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2892 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2895 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2898 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2899 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2900 builder.getContext(),
"omp.taskloop.start",
2901 builder.GetInsertBlock()->getParent());
2902 llvm::Instruction *branchToTaskloopStartBlock =
2903 builder.CreateBr(taskloopStartBlock);
2904 builder.SetInsertPoint(branchToTaskloopStartBlock);
2906 llvm::BasicBlock *copyBlock =
2907 splitBB(builder,
true,
"omp.private.copy");
2908 llvm::BasicBlock *initBlock =
2909 splitBB(builder,
true,
"omp.private.init");
2912 moduleTranslation, allocaIP);
2915 builder.SetInsertPoint(initBlock->getTerminator());
2918 taskStructMgr.generateTaskContextStruct();
2919 taskStructMgr.createGEPsToPrivateVars();
2921 llvmFirstPrivateVars.resize(privateVarsInfo.
blockArgs.size());
2923 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2925 privateVarsInfo.
blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2926 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2928 if (!privDecl.readsFromMold())
2930 assert(llvmPrivateVarAlloc &&
2931 "reads from mold so shouldn't have been skipped");
2934 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2935 blockArg, llvmPrivateVarAlloc, initBlock);
2936 if (!privateVarOrErr)
2937 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2939 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2941 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2942 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2944 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2945 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2946 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2948 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2949 llvmPrivateVarAlloc);
2951 assert(llvmPrivateVarAlloc->getType() ==
2952 moduleTranslation.
convertType(blockArg.getType()));
2958 taskloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
2959 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.
privatizers,
2960 taskloopOp.getPrivateNeedsBarrier())))
2961 return llvm::failure();
2964 builder.SetInsertPoint(taskloopStartBlock);
2966 auto bodyCB = [&](InsertPointTy allocaIP,
2967 InsertPointTy codegenIP) -> llvm::Error {
2971 moduleTranslation, allocaIP);
2974 builder.restoreIP(codegenIP);
2976 llvm::BasicBlock *privInitBlock =
nullptr;
2978 for (
auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2981 auto [blockArg, privDecl, mlirPrivVar] = zip;
2983 if (privDecl.readsFromMold())
2986 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2987 llvm::Type *llvmAllocType =
2988 moduleTranslation.
convertType(privDecl.getType());
2989 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2990 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2991 llvmAllocType,
nullptr,
"omp.private.alloc");
2994 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2995 blockArg, llvmPrivateVar, privInitBlock);
2996 if (!privateVarOrError)
2997 return privateVarOrError.takeError();
2998 moduleTranslation.
mapValue(blockArg, privateVarOrError.get());
2999 privateVarsInfo.
llvmVars[i] = privateVarOrError.get();
3002 taskStructMgr.createGEPsToPrivateVars();
3003 for (
auto [i, llvmPrivVar] :
3004 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3006 assert(privateVarsInfo.
llvmVars[i] &&
3007 "This is added in the loop above");
3010 privateVarsInfo.
llvmVars[i] = llvmPrivVar;
3015 for (
auto [blockArg, llvmPrivateVar, privateDecl] :
3019 if (!privateDecl.readsFromMold())
3022 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3023 llvmPrivateVar = builder.CreateLoad(
3024 moduleTranslation.
convertType(blockArg.getType()), llvmPrivateVar);
3026 assert(llvmPrivateVar->getType() ==
3027 moduleTranslation.
convertType(blockArg.getType()));
3028 moduleTranslation.
mapValue(blockArg, llvmPrivateVar);
3031 auto continuationBlockOrError =
3033 builder, moduleTranslation);
3035 if (failed(
handleError(continuationBlockOrError, opInst)))
3036 return llvm::make_error<PreviouslyReportedError>();
3038 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3046 taskloopOp.getLoc(), privateVarsInfo.
llvmVars,
3048 return llvm::make_error<PreviouslyReportedError>();
3051 taskStructMgr.freeStructPtr();
3053 return llvm::Error::success();
3059 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3060 llvm::Value *destPtr, llvm::Value *srcPtr)
3062 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3063 builder.restoreIP(codegenIP);
3066 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3068 builder.CreateLoad(ptrTy, srcPtr,
"omp.taskloop.context.src");
3070 TaskContextStructManager &srcStructMgr = taskStructMgr;
3071 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3073 destStructMgr.generateTaskContextStruct();
3074 llvm::Value *dest = destStructMgr.getStructPtr();
3075 dest->setName(
"omp.taskloop.context.dest");
3076 builder.CreateStore(dest, destPtr);
3079 srcStructMgr.createGEPsToPrivateVars(src);
3081 destStructMgr.createGEPsToPrivateVars(dest);
3084 for (
auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3085 llvm::zip_equal(privateVarsInfo.
privatizers, srcGEPs,
3088 if (!privDecl.readsFromMold())
3090 assert(llvmPrivateVarAlloc &&
3091 "reads from mold so shouldn't have been skipped");
3094 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3095 llvmPrivateVarAlloc, builder.GetInsertBlock());
3096 if (!privateVarOrErr)
3097 return privateVarOrErr.takeError();
3106 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3107 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3108 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3110 llvmPrivateVarAlloc = builder.CreateLoad(
3111 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3113 assert(llvmPrivateVarAlloc->getType() ==
3114 moduleTranslation.
convertType(blockArg.getType()));
3122 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
3123 privateVarsInfo.
privatizers, taskloopOp.getPrivateNeedsBarrier())))
3124 return llvm::make_error<PreviouslyReportedError>();
3126 return builder.saveIP();
3129 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
3139 llvm::Type *boundType =
3140 moduleTranslation.
lookupValue(lowerBounds[0])->getType();
3141 llvm::Value *lbVal =
nullptr;
3142 llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3143 llvm::Value *stepVal =
nullptr;
3144 if (loopOp.getCollapseNumLoops() > 1) {
3162 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3163 llvm::Value *loopLb = moduleTranslation.
lookupValue(lowerBounds[i]);
3164 llvm::Value *loopUb = moduleTranslation.
lookupValue(upperBounds[i]);
3165 llvm::Value *loopStep = moduleTranslation.
lookupValue(steps[i]);
3171 llvm::Value *loopLbMinusOne = builder.CreateSub(
3172 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3173 llvm::Value *loopUbMinusOne = builder.CreateSub(
3174 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3175 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3176 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3177 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3178 llvm::Value *loopTripCount =
3179 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3180 loopTripCount = builder.CreateBinaryIntrinsic(
3181 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3185 llvm::Value *loopTripCountDivStep =
3186 builder.CreateSDiv(loopTripCount, loopStep);
3187 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3188 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3189 llvm::Value *loopTripCountRem =
3190 builder.CreateSRem(loopTripCount, loopStep);
3191 loopTripCountRem = builder.CreateBinaryIntrinsic(
3192 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3193 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3195 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3198 builder.CreateAdd(loopTripCountDivStep,
3199 builder.CreateZExtOrTrunc(
3200 needsRoundUp, loopTripCountDivStep->getType()));
3201 ubVal = builder.CreateMul(ubVal, loopTripCount);
3203 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3204 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3206 lbVal = moduleTranslation.
lookupValue(lowerBounds[0]);
3207 ubVal = moduleTranslation.
lookupValue(upperBounds[0]);
3208 stepVal = moduleTranslation.
lookupValue(steps[0]);
3210 assert(lbVal !=
nullptr &&
"Expected value for lbVal");
3211 assert(ubVal !=
nullptr &&
"Expected value for ubVal");
3212 assert(stepVal !=
nullptr &&
"Expected value for stepVal");
3214 llvm::Value *ifCond =
nullptr;
3215 llvm::Value *grainsize =
nullptr;
3217 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
3218 mlir::Value numTasksVal = taskloopOp.getNumTasks();
3219 if (
Value ifVar = taskloopOp.getIfExpr())
3222 grainsize = moduleTranslation.
lookupValue(grainsizeVal);
3224 }
else if (numTasksVal) {
3225 grainsize = moduleTranslation.
lookupValue(numTasksVal);
3229 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull =
nullptr;
3230 if (taskStructMgr.getStructPtr())
3231 taskDupOrNull = taskDupCB;
3241 llvm::omp::Directive::OMPD_taskgroup);
3243 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3244 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3246 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
3247 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
3248 sched, moduleTranslation.
lookupValue(taskloopOp.getFinal()),
3249 taskloopOp.getMergeable(),
3250 moduleTranslation.
lookupValue(taskloopOp.getPriority()),
3251 loopOp.getCollapseNumLoops(), taskDupOrNull,
3252 taskStructMgr.getStructPtr());
3259 builder.restoreIP(*afterIP);
3267 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3271 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
3272 builder.restoreIP(codegenIP);
3274 builder, moduleTranslation)
3279 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3280 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3287 builder.restoreIP(*afterIP);
3306 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3310 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3312 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3316 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3319 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[0]);
3320 llvm::Type *ivType = step->getType();
3321 llvm::Value *chunk =
nullptr;
3322 if (wsloopOp.getScheduleChunk()) {
3323 llvm::Value *chunkVar =
3324 moduleTranslation.
lookupValue(wsloopOp.getScheduleChunk());
3325 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3328 omp::DistributeOp distributeOp =
nullptr;
3329 llvm::Value *distScheduleChunk =
nullptr;
3330 bool hasDistSchedule =
false;
3331 if (llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())) {
3332 distributeOp = cast<omp::DistributeOp>(opInst.
getParentOp());
3333 hasDistSchedule = distributeOp.getDistScheduleStatic();
3334 if (distributeOp.getDistScheduleChunkSize()) {
3335 llvm::Value *chunkVar = moduleTranslation.
lookupValue(
3336 distributeOp.getDistScheduleChunkSize());
3337 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3345 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3349 wsloopOp.getNumReductionVars());
3352 builder, moduleTranslation, privateVarsInfo, allocaIP);
3359 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3364 moduleTranslation, allocaIP, reductionDecls,
3365 privateReductionVariables, reductionVariableMap,
3366 deferredStores, isByRef)))
3375 wsloopOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3377 wsloopOp.getPrivateNeedsBarrier())))
3380 assert(afterAllocas.get()->getSinglePredecessor());
3381 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3383 afterAllocas.get()->getSinglePredecessor(),
3384 reductionDecls, privateReductionVariables,
3385 reductionVariableMap, isByRef, deferredStores)))
3389 bool isOrdered = wsloopOp.getOrdered().has_value();
3390 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3391 bool isSimd = wsloopOp.getScheduleSimd();
3392 bool loopNeedsBarrier = !wsloopOp.getNowait();
3397 llvm::omp::WorksharingLoopType workshareLoopType =
3398 llvm::isa_and_present<omp::DistributeOp>(opInst.
getParentOp())
3399 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3400 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3404 llvm::omp::Directive::OMPD_for);
3406 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3409 LinearClauseProcessor linearClauseProcessor;
3411 if (!wsloopOp.getLinearVars().empty()) {
3412 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3414 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3416 for (
auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3417 linearClauseProcessor.createLinearVar(
3418 builder, moduleTranslation, moduleTranslation.
lookupValue(linearVar),
3420 for (
mlir::Value linearStep : wsloopOp.getLinearStepVars())
3421 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3424 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3426 wsloopOp.getRegion(),
"omp.wsloop.region", builder, moduleTranslation);
3434 if (!wsloopOp.getLinearVars().empty()) {
3435 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3436 loopInfo->getPreheader());
3437 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3439 builder.saveIP(), llvm::omp::OMPD_barrier);
3442 builder.restoreIP(*afterBarrierIP);
3443 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3444 loopInfo->getIndVar());
3445 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3448 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3451 bool noLoopMode =
false;
3452 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3454 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3458 if (loopOp == targetCapturedOp) {
3459 omp::TargetRegionFlags kernelFlags =
3460 targetOp.getKernelExecFlags(targetCapturedOp);
3461 if (omp::bitEnumContainsAll(kernelFlags,
3462 omp::TargetRegionFlags::spmd |
3463 omp::TargetRegionFlags::no_loop) &&
3464 !omp::bitEnumContainsAny(kernelFlags,
3465 omp::TargetRegionFlags::generic))
3470 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3471 ompBuilder->applyWorkshareLoop(
3472 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3473 convertToScheduleKind(schedule), chunk, isSimd,
3474 scheduleMod == omp::ScheduleModifier::monotonic,
3475 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3476 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3482 if (!wsloopOp.getLinearVars().empty()) {
3483 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3484 assert(loopInfo->getLastIter() &&
3485 "`lastiter` in CanonicalLoopInfo is nullptr");
3486 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3487 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3488 loopInfo->getLastIter());
3491 for (
size_t index = 0;
index < wsloopOp.getLinearVars().size();
index++)
3492 linearClauseProcessor.rewriteInPlace(
3493 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3494 "omp.loop_nest.region",
index);
3496 builder.restoreIP(oldIP);
3504 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3505 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3518 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3520 assert(isByRef.size() == opInst.getNumReductionVars());
3533 opInst.getNumReductionVars());
3536 auto bodyGenCB = [&](InsertPointTy allocaIP,
3537 InsertPointTy codeGenIP) -> llvm::Error {
3539 builder, moduleTranslation, privateVarsInfo, allocaIP);
3541 return llvm::make_error<PreviouslyReportedError>();
3547 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3550 InsertPointTy(allocaIP.getBlock(),
3551 allocaIP.getBlock()->getTerminator()->getIterator());
3554 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3555 reductionDecls, privateReductionVariables, reductionVariableMap,
3556 deferredStores, isByRef)))
3557 return llvm::make_error<PreviouslyReportedError>();
3559 assert(afterAllocas.get()->getSinglePredecessor());
3560 builder.restoreIP(codeGenIP);
3566 return llvm::make_error<PreviouslyReportedError>();
3569 opInst, builder, moduleTranslation, privateVarsInfo.
mlirVars,
3571 opInst.getPrivateNeedsBarrier())))
3572 return llvm::make_error<PreviouslyReportedError>();
3575 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3576 afterAllocas.get()->getSinglePredecessor(),
3577 reductionDecls, privateReductionVariables,
3578 reductionVariableMap, isByRef, deferredStores)))
3579 return llvm::make_error<PreviouslyReportedError>();
3584 moduleTranslation, allocaIP);
3588 opInst.getRegion(),
"omp.par.region", builder, moduleTranslation);
3590 return regionBlock.takeError();
3593 if (opInst.getNumReductionVars() > 0) {
3598 owningReductionGenRefDataPtrGens;
3600 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3602 owningReductionGenRefDataPtrGens,
3603 privateReductionVariables, reductionInfos, isByRef);
3606 builder.SetInsertPoint((*regionBlock)->getTerminator());
3609 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3610 builder.SetInsertPoint(tempTerminator);
3612 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3613 ompBuilder->createReductions(
3614 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3616 if (!contInsertPoint)
3617 return contInsertPoint.takeError();
3619 if (!contInsertPoint->getBlock())
3620 return llvm::make_error<PreviouslyReportedError>();
3622 tempTerminator->eraseFromParent();
3623 builder.restoreIP(*contInsertPoint);
3626 return llvm::Error::success();
3629 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3630 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3639 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3640 InsertPointTy oldIP = builder.saveIP();
3641 builder.restoreIP(codeGenIP);
3646 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3647 [](omp::DeclareReductionOp reductionDecl) {
3648 return &reductionDecl.getCleanupRegion();
3651 reductionCleanupRegions, privateReductionVariables,
3652 moduleTranslation, builder,
"omp.reduction.cleanup")))
3653 return llvm::createStringError(
3654 "failed to inline `cleanup` region of `omp.declare_reduction`");
3659 return llvm::make_error<PreviouslyReportedError>();
3663 if (isCancellable) {
3664 auto IPOrErr = ompBuilder->createBarrier(
3665 llvm::OpenMPIRBuilder::LocationDescription(builder),
3666 llvm::omp::Directive::OMPD_unknown,
3670 return IPOrErr.takeError();
3673 builder.restoreIP(oldIP);
3674 return llvm::Error::success();
3677 llvm::Value *ifCond =
nullptr;
3678 if (
auto ifVar = opInst.getIfExpr())
3680 llvm::Value *numThreads =
nullptr;
3681 if (!opInst.getNumThreadsVars().empty())
3682 numThreads = moduleTranslation.
lookupValue(opInst.getNumThreads(0));
3683 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3684 if (
auto bind = opInst.getProcBindKind())
3687 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3689 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3691 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3692 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3693 ifCond, numThreads, pbKind, isCancellable);
3698 builder.restoreIP(*afterIP);
3703static llvm::omp::OrderKind
3706 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3708 case omp::ClauseOrderKind::Concurrent:
3709 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3711 llvm_unreachable(
"Unknown ClauseOrderKind kind");
3719 auto simdOp = cast<omp::SimdOp>(opInst);
3727 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3730 simdOp.getNumReductionVars());
3735 assert(isByRef.size() == simdOp.getNumReductionVars());
3737 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3741 builder, moduleTranslation, privateVarsInfo, allocaIP);
3746 LinearClauseProcessor linearClauseProcessor;
3748 if (!simdOp.getLinearVars().empty()) {
3749 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3751 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3752 for (
auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3753 bool isImplicit =
false;
3754 for (
auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3758 if (linearVar == mlirPrivVar) {
3760 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3761 llvmPrivateVar, idx);
3767 linearClauseProcessor.createLinearVar(
3768 builder, moduleTranslation,
3771 for (
mlir::Value linearStep : simdOp.getLinearStepVars())
3772 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3776 moduleTranslation, allocaIP, reductionDecls,
3777 privateReductionVariables, reductionVariableMap,
3778 deferredStores, isByRef)))
3789 assert(afterAllocas.get()->getSinglePredecessor());
3790 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3792 afterAllocas.get()->getSinglePredecessor(),
3793 reductionDecls, privateReductionVariables,
3794 reductionVariableMap, isByRef, deferredStores)))
3797 llvm::ConstantInt *simdlen =
nullptr;
3798 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3799 simdlen = builder.getInt64(simdlenVar.value());
3801 llvm::ConstantInt *safelen =
nullptr;
3802 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3803 safelen = builder.getInt64(safelenVar.value());
3805 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3808 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3809 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3811 for (
size_t i = 0; i < operands.size(); ++i) {
3812 llvm::Value *alignment =
nullptr;
3813 llvm::Value *llvmVal = moduleTranslation.
lookupValue(operands[i]);
3814 llvm::Type *ty = llvmVal->getType();
3816 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3817 alignment = builder.getInt64(intAttr.getInt());
3818 assert(ty->isPointerTy() &&
"Invalid type for aligned variable");
3819 assert(alignment &&
"Invalid alignment value");
3823 if (!intAttr.getValue().isPowerOf2())
3826 auto curInsert = builder.saveIP();
3827 builder.SetInsertPoint(sourceBlock);
3828 llvmVal = builder.CreateLoad(ty, llvmVal);
3829 builder.restoreIP(curInsert);
3830 alignedVars[llvmVal] = alignment;
3834 simdOp.getRegion(),
"omp.simd.region", builder, moduleTranslation);
3841 if (simdOp.getLinearVars().size()) {
3842 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3843 loopInfo->getPreheader());
3845 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3846 loopInfo->getIndVar());
3848 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3850 ompBuilder->applySimd(loopInfo, alignedVars,
3852 ? moduleTranslation.
lookupValue(simdOp.getIfExpr())
3854 order, simdlen, safelen);
3856 linearClauseProcessor.emitStoresForLinearVar(builder);
3859 bool hasOrderedRegions =
false;
3860 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
3861 hasOrderedRegions =
true;
3865 for (
size_t index = 0;
index < simdOp.getLinearVars().size();
index++) {
3866 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
3867 llvm::BasicBlock *endBB = *regionBlock;
3868 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3869 "omp.loop_nest.region",
index);
3871 if (hasOrderedRegions) {
3873 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3874 "omp.ordered.region",
index);
3876 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3877 "omp_region.finalize",
index);
3885 for (
auto [i, tuple] : llvm::enumerate(
3886 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3887 privateReductionVariables))) {
3888 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3890 OwningReductionGen gen =
makeReductionGen(decl, builder, moduleTranslation);
3891 llvm::Value *originalVariable = moduleTranslation.
lookupValue(reductionVar);
3892 llvm::Type *reductionType = moduleTranslation.
convertType(decl.getType());
3896 llvm::Value *redValue = originalVariable;
3899 builder.CreateLoad(reductionType, redValue,
"red.value." + Twine(i));
3900 llvm::Value *privateRedValue = builder.CreateLoad(
3901 reductionType, privateReductionVar,
"red.private.value." + Twine(i));
3902 llvm::Value *reduced;
3904 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3907 builder.restoreIP(res.get());
3911 builder.CreateStore(reduced, originalVariable);
3916 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3917 [](omp::DeclareReductionOp reductionDecl) {
3918 return &reductionDecl.getCleanupRegion();
3921 moduleTranslation, builder,
3922 "omp.reduction.cleanup")))
3935 auto loopOp = cast<omp::LoopNestOp>(opInst);
3941 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3946 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3947 llvm::Value *iv) -> llvm::Error {
3950 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3955 bodyInsertPoints.push_back(ip);
3957 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3958 return llvm::Error::success();
3961 builder.restoreIP(ip);
3963 loopOp.getRegion(),
"omp.loop_nest.region", builder, moduleTranslation);
3965 return regionBlock.takeError();
3967 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3968 return llvm::Error::success();
3976 for (
unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3977 llvm::Value *lowerBound =
3978 moduleTranslation.
lookupValue(loopOp.getLoopLowerBounds()[i]);
3979 llvm::Value *upperBound =
3980 moduleTranslation.
lookupValue(loopOp.getLoopUpperBounds()[i]);
3981 llvm::Value *step = moduleTranslation.
lookupValue(loopOp.getLoopSteps()[i]);
3986 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3987 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3989 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3991 computeIP = loopInfos.front()->getPreheaderIP();
3995 ompBuilder->createCanonicalLoop(
3996 loc, bodyGen, lowerBound, upperBound, step,
3997 true, loopOp.getLoopInclusive(), computeIP);
4002 loopInfos.push_back(*loopResult);
4005 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4006 loopInfos.front()->getAfterIP();
4009 if (
const auto &tiles = loopOp.getTileSizes()) {
4010 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4013 for (
auto tile : tiles.value()) {
4014 llvm::Value *tileVal = llvm::ConstantInt::get(ivType,
tile);
4015 tileSizes.push_back(tileVal);
4018 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4019 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4023 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4024 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4025 afterIP = {afterAfterBB, afterAfterBB->begin()};
4029 for (
const auto &newLoop : newLoops)
4030 loopInfos.push_back(newLoop);
4034 const auto &numCollapse = loopOp.getCollapseNumLoops();
4036 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4038 auto newTopLoopInfo =
4039 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4041 assert(newTopLoopInfo &&
"New top loop information is missing");
4042 moduleTranslation.
stackWalk<OpenMPLoopInfoStackFrame>(
4043 [&](OpenMPLoopInfoStackFrame &frame) {
4044 frame.loopInfo = newTopLoopInfo;
4052 builder.restoreIP(afterIP);
4062 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4063 Value loopIV = op.getInductionVar();
4064 Value loopTC = op.getTripCount();
4066 llvm::Value *llvmTC = moduleTranslation.
lookupValue(loopTC);
4069 ompBuilder->createCanonicalLoop(
4071 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4074 moduleTranslation.
mapValue(loopIV, llvmIV);
4076 builder.restoreIP(ip);
4081 return bodyGenStatus.takeError();
4083 llvmTC,
"omp.loop");
4085 return op.emitError(llvm::toString(llvmOrError.takeError()));
4087 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4088 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4089 builder.restoreIP(afterIP);
4092 if (
Value cli = op.getCli())
4105 Value applyee = op.getApplyee();
4106 assert(applyee &&
"Loop to apply unrolling on required");
4108 llvm::CanonicalLoopInfo *consBuilderCLI =
4110 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4111 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4119static LogicalResult
applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4122 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4127 for (
Value size : op.getSizes()) {
4128 llvm::Value *translatedSize = moduleTranslation.
lookupValue(size);
4129 assert(translatedSize &&
4130 "sizes clause arguments must already be translated");
4131 translatedSizes.push_back(translatedSize);
4134 for (
Value applyee : op.getApplyees()) {
4135 llvm::CanonicalLoopInfo *consBuilderCLI =
4137 assert(applyee &&
"Canonical loop must already been translated");
4138 translatedLoops.push_back(consBuilderCLI);
4141 auto generatedLoops =
4142 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4143 if (!op.getGeneratees().empty()) {
4144 for (
auto [mlirLoop,
genLoop] :
4145 zip_equal(op.getGeneratees(), generatedLoops))
4150 for (
Value applyee : op.getApplyees())
4158static LogicalResult
applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4161 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4165 for (
size_t i = 0; i < op.getApplyees().size(); i++) {
4166 Value applyee = op.getApplyees()[i];
4167 llvm::CanonicalLoopInfo *consBuilderCLI =
4169 assert(applyee &&
"Canonical loop must already been translated");
4170 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4171 beforeFuse.push_back(consBuilderCLI);
4172 else if (op.getCount().has_value() &&
4173 i >= op.getFirst().value() + op.getCount().value() - 1)
4174 afterFuse.push_back(consBuilderCLI);
4176 toFuse.push_back(consBuilderCLI);
4179 (op.getGeneratees().empty() ||
4180 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4181 "Wrong number of generatees");
4184 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4185 if (!op.getGeneratees().empty()) {
4187 for (; i < beforeFuse.size(); i++)
4188 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4189 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4190 for (; i < afterFuse.size(); i++)
4191 moduleTranslation.
mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4195 for (
Value applyee : op.getApplyees())
4202static llvm::AtomicOrdering
4205 return llvm::AtomicOrdering::Monotonic;
4208 case omp::ClauseMemoryOrderKind::Seq_cst:
4209 return llvm::AtomicOrdering::SequentiallyConsistent;
4210 case omp::ClauseMemoryOrderKind::Acq_rel:
4211 return llvm::AtomicOrdering::AcquireRelease;
4212 case omp::ClauseMemoryOrderKind::Acquire:
4213 return llvm::AtomicOrdering::Acquire;
4214 case omp::ClauseMemoryOrderKind::Release:
4215 return llvm::AtomicOrdering::Release;
4216 case omp::ClauseMemoryOrderKind::Relaxed:
4217 return llvm::AtomicOrdering::Monotonic;
4219 llvm_unreachable(
"Unknown ClauseMemoryOrderKind kind");
4226 auto readOp = cast<omp::AtomicReadOp>(opInst);
4231 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4234 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4237 llvm::Value *x = moduleTranslation.
lookupValue(readOp.getX());
4238 llvm::Value *v = moduleTranslation.
lookupValue(readOp.getV());
4240 llvm::Type *elementType =
4241 moduleTranslation.
convertType(readOp.getElementType());
4243 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType,
false,
false};
4244 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType,
false,
false};
4245 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4253 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4258 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4261 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4263 llvm::Value *expr = moduleTranslation.
lookupValue(writeOp.getExpr());
4264 llvm::Value *dest = moduleTranslation.
lookupValue(writeOp.getX());
4265 llvm::Type *ty = moduleTranslation.
convertType(writeOp.getExpr().getType());
4266 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty,
false,
4269 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4277 .Case([&](LLVM::AddOp) {
return llvm::AtomicRMWInst::BinOp::Add; })
4278 .Case([&](LLVM::SubOp) {
return llvm::AtomicRMWInst::BinOp::Sub; })
4279 .Case([&](LLVM::AndOp) {
return llvm::AtomicRMWInst::BinOp::And; })
4280 .Case([&](LLVM::OrOp) {
return llvm::AtomicRMWInst::BinOp::Or; })
4281 .Case([&](LLVM::XOrOp) {
return llvm::AtomicRMWInst::BinOp::Xor; })
4282 .Case([&](LLVM::UMaxOp) {
return llvm::AtomicRMWInst::BinOp::UMax; })
4283 .Case([&](LLVM::UMinOp) {
return llvm::AtomicRMWInst::BinOp::UMin; })
4284 .Case([&](LLVM::FAddOp) {
return llvm::AtomicRMWInst::BinOp::FAdd; })
4285 .Case([&](LLVM::FSubOp) {
return llvm::AtomicRMWInst::BinOp::FSub; })
4286 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4290 bool &isIgnoreDenormalMode,
4291 bool &isFineGrainedMemory,
4292 bool &isRemoteMemory) {
4293 isIgnoreDenormalMode =
false;
4294 isFineGrainedMemory =
false;
4295 isRemoteMemory =
false;
4296 if (atomicUpdateOp &&
4297 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4298 mlir::omp::AtomicControlAttr atomicControlAttr =
4299 atomicUpdateOp.getAtomicControlAttr();
4300 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4301 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4302 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4309 llvm::IRBuilderBase &builder,
4316 auto &innerOpList = opInst.getRegion().front().getOperations();
4317 bool isXBinopExpr{
false};
4318 llvm::AtomicRMWInst::BinOp binop;
4320 llvm::Value *llvmExpr =
nullptr;
4321 llvm::Value *llvmX =
nullptr;
4322 llvm::Type *llvmXElementType =
nullptr;
4323 if (innerOpList.size() == 2) {
4329 opInst.getRegion().getArgument(0))) {
4330 return opInst.emitError(
"no atomic update operation with region argument"
4331 " as operand found inside atomic.update region");
4334 isXBinopExpr = innerOp.
getOperand(0) == opInst.getRegion().getArgument(0);
4336 llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4340 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4342 llvmX = moduleTranslation.
lookupValue(opInst.getX());
4344 opInst.getRegion().getArgument(0).getType());
4345 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4349 llvm::AtomicOrdering atomicOrdering =
4354 [&opInst, &moduleTranslation](
4355 llvm::Value *atomicx,
4358 moduleTranslation.
mapValue(*opInst.getRegion().args_begin(), atomicx);
4359 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4360 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4361 return llvm::make_error<PreviouslyReportedError>();
4363 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4364 assert(yieldop && yieldop.getResults().size() == 1 &&
4365 "terminator must be omp.yield op and it must have exactly one "
4367 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4370 bool isIgnoreDenormalMode;
4371 bool isFineGrainedMemory;
4372 bool isRemoteMemory;
4377 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4378 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4379 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4380 atomicOrdering, binop, updateFn,
4381 isXBinopExpr, isIgnoreDenormalMode,
4382 isFineGrainedMemory, isRemoteMemory);
4387 builder.restoreIP(*afterIP);
4393 llvm::IRBuilderBase &builder,
4400 bool isXBinopExpr =
false, isPostfixUpdate =
false;
4401 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4403 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4404 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4406 assert((atomicUpdateOp || atomicWriteOp) &&
4407 "internal op must be an atomic.update or atomic.write op");
4409 if (atomicWriteOp) {
4410 isPostfixUpdate =
true;
4411 mlirExpr = atomicWriteOp.getExpr();
4413 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4414 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4415 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4418 if (innerOpList.size() == 2) {
4421 atomicUpdateOp.getRegion().getArgument(0))) {
4422 return atomicUpdateOp.emitError(
4423 "no atomic update operation with region argument"
4424 " as operand found inside atomic.update region");
4428 innerOp.
getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4431 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4435 llvm::Value *llvmExpr = moduleTranslation.
lookupValue(mlirExpr);
4436 llvm::Value *llvmX =
4437 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4438 llvm::Value *llvmV =
4439 moduleTranslation.
lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4440 llvm::Type *llvmXElementType = moduleTranslation.
convertType(
4441 atomicCaptureOp.getAtomicReadOp().getElementType());
4442 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4445 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4449 llvm::AtomicOrdering atomicOrdering =
4453 [&](llvm::Value *atomicx,
4456 return moduleTranslation.
lookupValue(atomicWriteOp.getExpr());
4457 Block &bb = *atomicUpdateOp.getRegion().
begin();
4458 moduleTranslation.
mapValue(*atomicUpdateOp.getRegion().args_begin(),
4460 moduleTranslation.
mapBlock(&bb, builder.GetInsertBlock());
4461 if (failed(moduleTranslation.
convertBlock(bb,
true, builder)))
4462 return llvm::make_error<PreviouslyReportedError>();
4464 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.
getTerminator());
4465 assert(yieldop && yieldop.getResults().size() == 1 &&
4466 "terminator must be omp.yield op and it must have exactly one "
4468 return moduleTranslation.
lookupValue(yieldop.getResults()[0]);
4471 bool isIgnoreDenormalMode;
4472 bool isFineGrainedMemory;
4473 bool isRemoteMemory;
4475 isFineGrainedMemory, isRemoteMemory);
4478 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4479 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4480 ompBuilder->createAtomicCapture(
4481 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4482 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4483 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4485 if (failed(
handleError(afterIP, *atomicCaptureOp)))
4488 builder.restoreIP(*afterIP);
4493 omp::ClauseCancellationConstructType directive) {
4494 switch (directive) {
4495 case omp::ClauseCancellationConstructType::Loop:
4496 return llvm::omp::Directive::OMPD_for;
4497 case omp::ClauseCancellationConstructType::Parallel:
4498 return llvm::omp::Directive::OMPD_parallel;
4499 case omp::ClauseCancellationConstructType::Sections:
4500 return llvm::omp::Directive::OMPD_sections;
4501 case omp::ClauseCancellationConstructType::Taskgroup:
4502 return llvm::omp::Directive::OMPD_taskgroup;
4504 llvm_unreachable(
"Unhandled cancellation construct type");
4513 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4516 llvm::Value *ifCond =
nullptr;
4517 if (
Value ifVar = op.getIfExpr())
4520 llvm::omp::Directive cancelledDirective =
4523 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4524 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4526 if (failed(
handleError(afterIP, *op.getOperation())))
4529 builder.restoreIP(afterIP.get());
4536 llvm::IRBuilderBase &builder,
4541 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4544 llvm::omp::Directive cancelledDirective =
4547 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4548 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4550 if (failed(
handleError(afterIP, *op.getOperation())))
4553 builder.restoreIP(afterIP.get());
4563 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4565 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4570 Value symAddr = threadprivateOp.getSymAddr();
4573 if (
auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4576 if (!isa<LLVM::AddressOfOp>(symOp))
4577 return opInst.
emitError(
"Addressing symbol not found");
4578 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4580 LLVM::GlobalOp global =
4581 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
4582 llvm::GlobalValue *globalValue = moduleTranslation.
lookupGlobal(global);
4583 llvm::Type *type = globalValue->getValueType();
4584 llvm::TypeSize typeSize =
4585 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4587 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4588 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4589 ompLoc, globalValue, size, global.getSymName() +
".cache");
4595static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4597 switch (deviceClause) {
4598 case mlir::omp::DeclareTargetDeviceType::host:
4599 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4601 case mlir::omp::DeclareTargetDeviceType::nohost:
4602 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4604 case mlir::omp::DeclareTargetDeviceType::any:
4605 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4608 llvm_unreachable(
"unhandled device clause");
4611static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4613 mlir::omp::DeclareTargetCaptureClause captureClause) {
4614 switch (captureClause) {
4615 case mlir::omp::DeclareTargetCaptureClause::to:
4616 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4617 case mlir::omp::DeclareTargetCaptureClause::link:
4618 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4619 case mlir::omp::DeclareTargetCaptureClause::enter:
4620 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4621 case mlir::omp::DeclareTargetCaptureClause::none:
4622 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4624 llvm_unreachable(
"unhandled capture clause");
4629 if (
auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4631 if (
auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4632 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4633 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4638static llvm::SmallString<64>
4640 llvm::OpenMPIRBuilder &ompBuilder) {
4642 llvm::raw_svector_ostream os(suffix);
4645 auto fileInfoCallBack = [&loc]() {
4646 return std::pair<std::string, uint64_t>(
4647 llvm::StringRef(loc.getFilename()), loc.getLine());
4650 auto vfs = llvm::vfs::getRealFileSystem();
4653 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4655 os <<
"_decl_tgt_ref_ptr";
4661 if (
auto declareTargetGlobal =
4662 dyn_cast_if_present<omp::DeclareTargetInterface>(
4664 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4665 omp::DeclareTargetCaptureClause::link)
4671 if (
auto declareTargetGlobal =
4672 dyn_cast_if_present<omp::DeclareTargetInterface>(
4674 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4675 omp::DeclareTargetCaptureClause::to ||
4676 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4677 omp::DeclareTargetCaptureClause::enter)
4691 if (
auto declareTargetGlobal =
4692 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4695 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4696 omp::DeclareTargetCaptureClause::link) ||
4697 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4698 omp::DeclareTargetCaptureClause::to &&
4699 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4703 if (gOp.getSymName().contains(suffix))
4708 (gOp.getSymName().str() + suffix.str()).str());
4717struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4718 SmallVector<Operation *, 4> Mappers;
4721 void append(MapInfosTy &curInfo) {
4722 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4723 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4732struct MapInfoData : MapInfosTy {
4733 llvm::SmallVector<bool, 4> IsDeclareTarget;
4734 llvm::SmallVector<bool, 4> IsAMember;
4736 llvm::SmallVector<bool, 4> IsAMapping;
4737 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4738 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4741 llvm::SmallVector<llvm::Type *, 4> BaseType;
4744 void append(MapInfoData &CurInfo) {
4745 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4746 CurInfo.IsDeclareTarget.end());
4747 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4748 OriginalValue.append(CurInfo.OriginalValue.begin(),
4749 CurInfo.OriginalValue.end());
4750 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4751 MapInfosTy::append(CurInfo);
/// Discriminates which target-related construct a map clause belongs to, so
/// map flags can be adjusted per directive. Enumerator names follow the ops
/// dispatched on by getTargetDirectiveEnumTyFromOp; explicit values inferred
/// from the visible `TargetEnterData = 3`.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5,
};
4764static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4765 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4766 .Case([](omp::TargetDataOp) {
return TargetDirectiveEnumTy::TargetData; })
4767 .Case([](omp::TargetEnterDataOp) {
4768 return TargetDirectiveEnumTy::TargetEnterData;
4770 .Case([&](omp::TargetExitDataOp) {
4771 return TargetDirectiveEnumTy::TargetExitData;
4773 .Case([&](omp::TargetUpdateOp) {
4774 return TargetDirectiveEnumTy::TargetUpdate;
4776 .Case([&](omp::TargetOp) {
return TargetDirectiveEnumTy::Target; })
4777 .Default([&](Operation *op) {
return TargetDirectiveEnumTy::None; });
4784 if (
auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4785 arrTy.getElementType()))
4802 llvm::Value *basePointer,
4803 llvm::Type *baseType,
4804 llvm::IRBuilderBase &builder,
4806 if (
auto memberClause =
4807 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
4812 if (!memberClause.getBounds().empty()) {
4813 llvm::Value *elementCount = builder.getInt64(1);
4814 for (
auto bounds : memberClause.getBounds()) {
4815 if (
auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
4816 bounds.getDefiningOp())) {
4821 elementCount = builder.CreateMul(
4825 moduleTranslation.
lookupValue(boundOp.getUpperBound()),
4826 moduleTranslation.
lookupValue(boundOp.getLowerBound())),
4827 builder.getInt64(1)));
4834 if (
auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
4842 return builder.CreateMul(elementCount,
4843 builder.getInt64(underlyingTypeSzInBits / 8));
4854static llvm::omp::OpenMPOffloadMappingFlags
4856 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4857 return (mlirFlags & flag) == flag;
4859 const bool hasExplicitMap =
4860 (mlirFlags &
~omp::ClauseMapFlags::is_device_ptr) !=
4861 omp::ClauseMapFlags::none;
4863 llvm::omp::OpenMPOffloadMappingFlags mapType =
4864 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4867 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4870 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4873 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4876 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4879 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4882 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4885 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4888 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4891 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4894 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4897 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4900 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4903 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4904 if (!hasExplicitMap)
4905 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4915 ArrayRef<Value> useDevAddrOperands = {},
4916 ArrayRef<Value> hasDevAddrOperands = {}) {
4917 auto checkIsAMember = [](
const auto &mapVars,
auto mapOp) {
4925 for (Value mapValue : mapVars) {
4926 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4927 for (
auto member : map.getMembers())
4928 if (member == mapOp)
4935 for (Value mapValue : mapVars) {
4936 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4938 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4939 mapData.OriginalValue.push_back(moduleTranslation.
lookupValue(offloadPtr));
4940 mapData.Pointers.push_back(mapData.OriginalValue.back());
4942 if (llvm::Value *refPtr =
4944 mapData.IsDeclareTarget.push_back(
true);
4945 mapData.BasePointers.push_back(refPtr);
4947 mapData.IsDeclareTarget.push_back(
true);
4948 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4950 mapData.IsDeclareTarget.push_back(
false);
4951 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4954 mapData.BaseType.push_back(
4955 moduleTranslation.
convertType(mapOp.getVarType()));
4956 mapData.Sizes.push_back(
4957 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4958 mapData.BaseType.back(), builder, moduleTranslation));
4959 mapData.MapClause.push_back(mapOp.getOperation());
4961 mapData.Names.push_back(LLVM::createMappingInformation(
4963 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4964 if (mapOp.getMapperId())
4965 mapData.Mappers.push_back(
4967 mapOp, mapOp.getMapperIdAttr()));
4969 mapData.Mappers.push_back(
nullptr);
4970 mapData.IsAMapping.push_back(
true);
4971 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4974 auto findMapInfo = [&mapData](llvm::Value *val,
4975 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4978 for (llvm::Value *basePtr : mapData.OriginalValue) {
4979 if (basePtr == val && mapData.IsAMapping[index]) {
4981 mapData.Types[index] |=
4982 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4983 mapData.DevicePointers[index] = devInfoTy;
4991 auto addDevInfos = [&](
const llvm::ArrayRef<Value> &useDevOperands,
4992 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4993 for (Value mapValue : useDevOperands) {
4994 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4996 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4997 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5000 if (!findMapInfo(origValue, devInfoTy)) {
5001 mapData.OriginalValue.push_back(origValue);
5002 mapData.Pointers.push_back(mapData.OriginalValue.back());
5003 mapData.IsDeclareTarget.push_back(
false);
5004 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5005 mapData.BaseType.push_back(
5006 moduleTranslation.
convertType(mapOp.getVarType()));
5007 mapData.Sizes.push_back(builder.getInt64(0));
5008 mapData.MapClause.push_back(mapOp.getOperation());
5009 mapData.Types.push_back(
5010 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5011 mapData.Names.push_back(LLVM::createMappingInformation(
5013 mapData.DevicePointers.push_back(devInfoTy);
5014 mapData.Mappers.push_back(
nullptr);
5015 mapData.IsAMapping.push_back(
false);
5016 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5021 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5022 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5024 for (Value mapValue : hasDevAddrOperands) {
5025 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5027 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5028 llvm::Value *origValue = moduleTranslation.
lookupValue(offloadPtr);
5030 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5032 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5033 omp::ClauseMapFlags::none;
5035 mapData.OriginalValue.push_back(origValue);
5036 mapData.BasePointers.push_back(origValue);
5037 mapData.Pointers.push_back(origValue);
5038 mapData.IsDeclareTarget.push_back(
false);
5039 mapData.BaseType.push_back(
5040 moduleTranslation.
convertType(mapOp.getVarType()));
5041 mapData.Sizes.push_back(
5042 builder.getInt64(dl.
getTypeSize(mapOp.getVarType())));
5043 mapData.MapClause.push_back(mapOp.getOperation());
5044 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5048 mapData.Types.push_back(mapType);
5052 if (mapOp.getMapperId()) {
5053 mapData.Mappers.push_back(
5055 mapOp, mapOp.getMapperIdAttr()));
5057 mapData.Mappers.push_back(
nullptr);
5062 mapData.Types.push_back(
5063 isDevicePtr ? mapType
5064 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5065 mapData.Mappers.push_back(
nullptr);
5067 mapData.Names.push_back(LLVM::createMappingInformation(
5069 mapData.DevicePointers.push_back(
5070 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5071 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5072 mapData.IsAMapping.push_back(
false);
5073 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5078 auto *res = llvm::find(mapData.MapClause, memberOp);
5079 assert(res != mapData.MapClause.end() &&
5080 "MapInfoOp for member not found in MapData, cannot return index");
5081 return std::distance(mapData.MapClause.begin(), res);
5085 omp::MapInfoOp mapInfo) {
5086 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5096 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5097 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5099 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5100 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5101 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5103 if (aIndex == bIndex)
5106 if (aIndex < bIndex)
5109 if (aIndex > bIndex)
5116 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5118 occludedChildren.push_back(
b);
5120 occludedChildren.push_back(a);
5121 return memberAParent;
5127 for (
auto v : occludedChildren)
5134 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5136 if (indexAttr.size() == 1)
5137 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5141 return llvm::cast<omp::MapInfoOp>(
5166static std::vector<llvm::Value *>
5168 llvm::IRBuilderBase &builder,
bool isArrayTy,
5170 std::vector<llvm::Value *> idx;
5181 idx.push_back(builder.getInt64(0));
5182 for (
int i = bounds.size() - 1; i >= 0; --i) {
5183 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5184 bounds[i].getDefiningOp())) {
5185 idx.push_back(moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5203 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5204 for (
int i = bounds.size() - 1; i >= 0; --i) {
5205 if (
auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5206 bounds[i].getDefiningOp())) {
5207 if (i == ((
int)bounds.size() - 1))
5209 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5211 idx.back() = builder.CreateAdd(
5212 builder.CreateMul(idx.back(), moduleTranslation.
lookupValue(
5213 boundOp.getExtent())),
5214 moduleTranslation.
lookupValue(boundOp.getLowerBound()));
5223 llvm::transform(values, std::back_inserter(ints), [](
Attribute value) {
5224 return cast<IntegerAttr>(value).getInt();
5232 omp::MapInfoOp parentOp) {
5234 if (parentOp.getMembers().empty())
5238 if (parentOp.getMembers().size() == 1) {
5239 overlapMapDataIdxs.push_back(0);
5245 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5246 for (
auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5247 memberByIndex.push_back(
5248 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
5253 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5254 [&](
auto a,
auto b) { return a.second.size() < b.second.size(); });
5260 for (
auto v : memberByIndex) {
5264 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](
auto x) {
5267 llvm::SmallVector<int64_t> xArr(x.second.size());
5268 getAsIntegers(x.second, xArr);
5269 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5270 xArr.size() >= vArr.size();
5276 for (
auto v : memberByIndex)
5277 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5278 overlapMapDataIdxs.push_back(v.first);
5290 if (mapOp.getVarPtrPtr())
5319 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5320 MapInfoData &mapData, uint64_t mapDataIndex,
5321 TargetDirectiveEnumTy targetDirective) {
5322 assert(!ompBuilder.Config.isTargetDevice() &&
5323 "function only supported for host device codegen");
5326 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5328 auto *parentMapper = mapData.Mappers[mapDataIndex];
5334 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
5335 (targetDirective == TargetDirectiveEnumTy::Target &&
5336 !mapData.IsDeclareTarget[mapDataIndex])
5337 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
5338 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5341 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
5345 mapFlags parentFlags = mapData.Types[mapDataIndex];
5346 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
5347 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
5348 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
5349 baseFlag |= (parentFlags & preserve);
5352 combinedInfo.Types.emplace_back(baseFlag);
5353 combinedInfo.DevicePointers.emplace_back(
5354 mapData.DevicePointers[mapDataIndex]);
5358 combinedInfo.Mappers.emplace_back(
5359 parentMapper && !parentClause.getPartialMap() ? parentMapper :
nullptr);
5361 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5362 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
5371 llvm::Value *lowAddr, *highAddr;
5372 if (!parentClause.getPartialMap()) {
5373 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5374 builder.getPtrTy());
5375 highAddr = builder.CreatePointerCast(
5376 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5377 mapData.Pointers[mapDataIndex], 1),
5378 builder.getPtrTy());
5379 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5381 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5384 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5385 builder.getPtrTy());
5388 highAddr = builder.CreatePointerCast(
5389 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5390 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5391 builder.getPtrTy());
5392 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
5395 llvm::Value *size = builder.CreateIntCast(
5396 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5397 builder.getInt64Ty(),
5399 combinedInfo.Sizes.push_back(size);
5401 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5402 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5410 if (!parentClause.getPartialMap()) {
5415 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5416 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5417 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5418 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5419 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5421 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
5422 combinedInfo.Types.emplace_back(mapFlag);
5423 combinedInfo.DevicePointers.emplace_back(
5424 mapData.DevicePointers[mapDataIndex]);
5426 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5427 combinedInfo.BasePointers.emplace_back(
5428 mapData.BasePointers[mapDataIndex]);
5429 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5430 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5431 combinedInfo.Mappers.emplace_back(
nullptr);
5442 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5443 builder.getPtrTy());
5444 highAddr = builder.CreatePointerCast(
5445 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5446 mapData.Pointers[mapDataIndex], 1),
5447 builder.getPtrTy());
5454 for (
auto v : overlapIdxs) {
5457 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5458 combinedInfo.Types.emplace_back(mapFlag);
5459 combinedInfo.DevicePointers.emplace_back(
5460 mapData.DevicePointers[mapDataOverlapIdx]);
5462 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5463 combinedInfo.BasePointers.emplace_back(
5464 mapData.BasePointers[mapDataIndex]);
5465 combinedInfo.Mappers.emplace_back(
nullptr);
5466 combinedInfo.Pointers.emplace_back(lowAddr);
5467 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5468 builder.CreatePtrDiff(builder.getInt8Ty(),
5469 mapData.OriginalValue[mapDataOverlapIdx],
5471 builder.getInt64Ty(),
true));
5472 lowAddr = builder.CreateConstGEP1_32(
5474 mapData.MapClause[mapDataOverlapIdx]))
5475 ? builder.getPtrTy()
5476 : mapData.BaseType[mapDataOverlapIdx],
5477 mapData.BasePointers[mapDataOverlapIdx], 1);
5480 combinedInfo.Types.emplace_back(mapFlag);
5481 combinedInfo.DevicePointers.emplace_back(
5482 mapData.DevicePointers[mapDataIndex]);
5484 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5485 combinedInfo.BasePointers.emplace_back(
5486 mapData.BasePointers[mapDataIndex]);
5487 combinedInfo.Mappers.emplace_back(
nullptr);
5488 combinedInfo.Pointers.emplace_back(lowAddr);
5489 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5490 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5491 builder.getInt64Ty(),
true));
5494 return memberOfFlag;
5500 llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
5501 MapInfoData &mapData, uint64_t mapDataIndex,
5502 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5503 TargetDirectiveEnumTy targetDirective) {
5504 assert(!ompBuilder.Config.isTargetDevice() &&
5505 "function only supported for host device codegen");
5508 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5510 for (
auto mappedMembers : parentClause.getMembers()) {
5512 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5515 assert(memberDataIdx >= 0 &&
"could not find mapped member of structure");
5526 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5527 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5528 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5529 combinedInfo.Types.emplace_back(mapFlag);
5530 combinedInfo.DevicePointers.emplace_back(
5531 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5532 combinedInfo.Mappers.emplace_back(
nullptr);
5533 combinedInfo.Names.emplace_back(
5535 combinedInfo.BasePointers.emplace_back(
5536 mapData.BasePointers[mapDataIndex]);
5537 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5538 combinedInfo.Sizes.emplace_back(builder.getInt64(
5539 moduleTranslation.
getLLVMModule()->getDataLayout().getPointerSize()));
5545 mapFlag &=
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5546 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5547 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5549 ? parentClause.getVarPtr()
5550 : parentClause.getVarPtrPtr());
5553 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5554 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5555 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5558 combinedInfo.Types.emplace_back(mapFlag);
5559 combinedInfo.DevicePointers.emplace_back(
5560 mapData.DevicePointers[memberDataIdx]);
5561 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5562 combinedInfo.Names.emplace_back(
5564 uint64_t basePointerIndex =
5566 combinedInfo.BasePointers.emplace_back(
5567 mapData.BasePointers[basePointerIndex]);
5568 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5570 llvm::Value *size = mapData.Sizes[memberDataIdx];
5572 size = builder.CreateSelect(
5573 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5574 builder.getInt64(0), size);
5577 combinedInfo.Sizes.emplace_back(size);
5582 MapInfosTy &combinedInfo,
5583 TargetDirectiveEnumTy targetDirective,
5584 int mapDataParentIdx = -1) {
5588 auto mapFlag = mapData.Types[mapDataIdx];
5589 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5593 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5595 if (targetDirective == TargetDirectiveEnumTy::Target &&
5596 !mapData.IsDeclareTarget[mapDataIdx])
5597 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5599 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5601 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5606 if (mapDataParentIdx >= 0)
5607 combinedInfo.BasePointers.emplace_back(
5608 mapData.BasePointers[mapDataParentIdx]);
5610 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5612 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5613 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5614 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5615 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5616 combinedInfo.Types.emplace_back(mapFlag);
5617 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5621 llvm::IRBuilderBase &builder,
5622 llvm::OpenMPIRBuilder &ompBuilder,
5624 MapInfoData &mapData, uint64_t mapDataIndex,
5625 TargetDirectiveEnumTy targetDirective) {
5626 assert(!ompBuilder.Config.isTargetDevice() &&
5627 "function only supported for host device codegen");
5630 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5635 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5636 auto memberClause = llvm::cast<omp::MapInfoOp>(
5637 parentClause.getMembers()[0].getDefiningOp());
5654 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5656 combinedInfo, mapData, mapDataIndex,
5659 combinedInfo, mapData, mapDataIndex,
5660 memberOfParentFlag, targetDirective);
5670 llvm::IRBuilderBase &builder) {
5672 "function only supported for host device codegen");
5673 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5675 if (!mapData.IsDeclareTarget[i]) {
5676 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5677 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5687 switch (captureKind) {
5688 case omp::VariableCaptureKind::ByRef: {
5689 llvm::Value *newV = mapData.Pointers[i];
5691 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5694 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5696 if (!offsetIdx.empty())
5697 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5699 mapData.Pointers[i] = newV;
5701 case omp::VariableCaptureKind::ByCopy: {
5702 llvm::Type *type = mapData.BaseType[i];
5704 if (mapData.Pointers[i]->getType()->isPointerTy())
5705 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5707 newV = mapData.Pointers[i];
5710 auto curInsert = builder.saveIP();
5711 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5713 auto *memTempAlloc =
5714 builder.CreateAlloca(builder.getPtrTy(),
nullptr,
".casted");
5715 builder.SetCurrentDebugLocation(DbgLoc);
5716 builder.restoreIP(curInsert);
5718 builder.CreateStore(newV, memTempAlloc);
5719 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5722 mapData.Pointers[i] = newV;
5723 mapData.BasePointers[i] = newV;
5725 case omp::VariableCaptureKind::This:
5726 case omp::VariableCaptureKind::VLAType:
5727 mapData.MapClause[i]->emitOpError(
"Unhandled capture kind");
5738 MapInfoData &mapData,
5739 TargetDirectiveEnumTy targetDirective) {
5741 "function only supported for host device codegen");
5762 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
5765 if (mapData.IsAMember[i])
5768 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5769 if (!mapInfoOp.getMembers().empty()) {
5771 combinedInfo, mapData, i, targetDirective);
5779static llvm::Expected<llvm::Function *>
5781 LLVM::ModuleTranslation &moduleTranslation,
5782 llvm::StringRef mapperFuncName,
5783 TargetDirectiveEnumTy targetDirective);
5785static llvm::Expected<llvm::Function *>
5788 TargetDirectiveEnumTy targetDirective) {
5790 "function only supported for host device codegen");
5791 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5792 std::string mapperFuncName =
5794 {
"omp_mapper", declMapperOp.getSymName()});
5796 if (
auto *lookupFunc = moduleTranslation.
lookupFunction(mapperFuncName))
5804 if (llvm::Function *existingFunc =
5805 moduleTranslation.
getLLVMModule()->getFunction(mapperFuncName)) {
5806 moduleTranslation.
mapFunction(mapperFuncName, existingFunc);
5807 return existingFunc;
5811 mapperFuncName, targetDirective);
5814static llvm::Expected<llvm::Function *>
5817 llvm::StringRef mapperFuncName,
5818 TargetDirectiveEnumTy targetDirective) {
5820 "function only supported for host device codegen");
5821 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5822 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
5825 llvm::Type *varType = moduleTranslation.
convertType(declMapperOp.getType());
5828 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5831 MapInfosTy combinedInfo;
5833 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
5834 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
5835 builder.restoreIP(codeGenIP);
5836 moduleTranslation.
mapValue(declMapperOp.getSymVal(), ptrPHI);
5837 moduleTranslation.
mapBlock(&declMapperOp.getRegion().front(),
5838 builder.GetInsertBlock());
5839 if (failed(moduleTranslation.
convertBlock(declMapperOp.getRegion().front(),
5842 return llvm::make_error<PreviouslyReportedError>();
5843 MapInfoData mapData;
5846 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
5852 return combinedInfo;
5856 if (!combinedInfo.Mappers[i])
5859 moduleTranslation, targetDirective);
5863 genMapInfoCB, varType, mapperFuncName, customMapperCB);
5865 return newFn.takeError();
5866 if ([[maybe_unused]] llvm::Function *mappedFunc =
5868 assert(mappedFunc == *newFn &&
5869 "mapper function mapping disagrees with emitted function");
5871 moduleTranslation.
mapFunction(mapperFuncName, *newFn);
5879 llvm::Value *ifCond =
nullptr;
5880 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5884 llvm::omp::RuntimeFunction RTLFn;
5886 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5889 llvm::OpenMPIRBuilder::TargetDataInfo info(
5892 assert(!ompBuilder->Config.isTargetDevice() &&
5893 "target data/enter/exit/update are host ops");
5894 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5896 auto getDeviceID = [&](
mlir::Value dev) -> llvm::Value * {
5897 llvm::Value *v = moduleTranslation.
lookupValue(dev);
5898 return builder.CreateIntCast(v, builder.getInt64Ty(),
true);
5903 .Case([&](omp::TargetDataOp dataOp) {
5907 if (
auto ifVar = dataOp.getIfExpr())
5911 deviceID = getDeviceID(devId);
5913 mapVars = dataOp.getMapVars();
5914 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5915 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5918 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5922 if (
auto ifVar = enterDataOp.getIfExpr())
5926 deviceID = getDeviceID(devId);
5929 enterDataOp.getNowait()
5930 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5931 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5932 mapVars = enterDataOp.getMapVars();
5933 info.HasNoWait = enterDataOp.getNowait();
5936 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5940 if (
auto ifVar = exitDataOp.getIfExpr())
5944 deviceID = getDeviceID(devId);
5946 RTLFn = exitDataOp.getNowait()
5947 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5948 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5949 mapVars = exitDataOp.getMapVars();
5950 info.HasNoWait = exitDataOp.getNowait();
5953 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5957 if (
auto ifVar = updateDataOp.getIfExpr())
5961 deviceID = getDeviceID(devId);
5964 updateDataOp.getNowait()
5965 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5966 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5967 mapVars = updateDataOp.getMapVars();
5968 info.HasNoWait = updateDataOp.getNowait();
5971 .DefaultUnreachable(
"unexpected operation");
5976 if (!isOffloadEntry)
5977 ifCond = builder.getFalse();
5979 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5980 MapInfoData mapData;
5982 builder, useDevicePtrVars, useDeviceAddrVars);
5985 MapInfosTy combinedInfo;
5986 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5987 builder.restoreIP(codeGenIP);
5988 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5990 return combinedInfo;
5996 [&moduleTranslation](
5997 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6001 for (
auto [arg, useDevVar] :
6002 llvm::zip_equal(blockArgs, useDeviceVars)) {
6004 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6005 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6006 : mapInfoOp.getVarPtr();
6009 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6010 for (
auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6011 mapInfoData.MapClause, mapInfoData.DevicePointers,
6012 mapInfoData.BasePointers)) {
6013 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6014 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6015 devicePointer != type)
6018 if (llvm::Value *devPtrInfoMap =
6019 mapper ? mapper(basePointer) : basePointer) {
6020 moduleTranslation.
mapValue(arg, devPtrInfoMap);
6027 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6028 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6029 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6032 builder.restoreIP(codeGenIP);
6033 assert(isa<omp::TargetDataOp>(op) &&
6034 "BodyGen requested for non TargetDataOp");
6035 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6036 Region ®ion = cast<omp::TargetDataOp>(op).getRegion();
6037 switch (bodyGenType) {
6038 case BodyGenTy::Priv:
6040 if (!info.DevicePtrInfoMap.empty()) {
6041 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6042 blockArgIface.getUseDeviceAddrBlockArgs(),
6043 useDeviceAddrVars, mapData,
6044 [&](llvm::Value *basePointer) -> llvm::Value * {
6045 if (!info.DevicePtrInfoMap[basePointer].second)
6047 return builder.CreateLoad(
6049 info.DevicePtrInfoMap[basePointer].second);
6051 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6052 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6053 mapData, [&](llvm::Value *basePointer) {
6054 return info.DevicePtrInfoMap[basePointer].second;
6058 moduleTranslation)))
6059 return llvm::make_error<PreviouslyReportedError>();
6062 case BodyGenTy::DupNoPriv:
6063 if (info.DevicePtrInfoMap.empty()) {
6066 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6067 blockArgIface.getUseDeviceAddrBlockArgs(),
6068 useDeviceAddrVars, mapData);
6069 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6070 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6074 case BodyGenTy::NoPriv:
6076 if (info.DevicePtrInfoMap.empty()) {
6078 moduleTranslation)))
6079 return llvm::make_error<PreviouslyReportedError>();
6083 return builder.saveIP();
6086 auto customMapperCB =
6088 if (!combinedInfo.Mappers[i])
6090 info.HasMapper =
true;
6092 moduleTranslation, targetDirective);
6095 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6096 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6098 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6099 if (isa<omp::TargetDataOp>(op))
6100 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6101 deviceID, ifCond, info, genMapInfoCB,
6105 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6106 deviceID, ifCond, info, genMapInfoCB,
6107 customMapperCB, &RTLFn);
6113 builder.restoreIP(*afterIP);
6121 auto distributeOp = cast<omp::DistributeOp>(opInst);
6128 bool doDistributeReduction =
6132 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
6137 if (doDistributeReduction) {
6138 isByRef =
getIsByRef(teamsOp.getReductionByref());
6139 assert(isByRef.size() == teamsOp.getNumReductionVars());
6142 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6146 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6147 .getReductionBlockArgs();
6150 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6151 reductionDecls, privateReductionVariables, reductionVariableMap,
6156 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6157 auto bodyGenCB = [&](InsertPointTy allocaIP,
6158 InsertPointTy codeGenIP) -> llvm::Error {
6162 moduleTranslation, allocaIP);
6165 builder.restoreIP(codeGenIP);
6171 return llvm::make_error<PreviouslyReportedError>();
6176 return llvm::make_error<PreviouslyReportedError>();
6179 distributeOp, builder, moduleTranslation, privVarsInfo.
mlirVars,
6181 distributeOp.getPrivateNeedsBarrier())))
6182 return llvm::make_error<PreviouslyReportedError>();
6185 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6188 builder, moduleTranslation);
6190 return regionBlock.takeError();
6191 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
6196 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
6199 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6200 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6201 : omp::ClauseScheduleKind::Static;
6203 bool isOrdered = hasDistSchedule;
6204 std::optional<omp::ScheduleModifier> scheduleMod;
6205 bool isSimd =
false;
6206 llvm::omp::WorksharingLoopType workshareLoopType =
6207 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6208 bool loopNeedsBarrier =
false;
6209 llvm::Value *chunk = moduleTranslation.
lookupValue(
6210 distributeOp.getDistScheduleChunkSize());
6211 llvm::CanonicalLoopInfo *loopInfo =
6213 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6214 ompBuilder->applyWorkshareLoop(
6215 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6216 convertToScheduleKind(schedule), chunk, isSimd,
6217 scheduleMod == omp::ScheduleModifier::monotonic,
6218 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6219 workshareLoopType,
false, hasDistSchedule, chunk);
6222 return wsloopIP.takeError();
6225 distributeOp.getLoc(), privVarsInfo.
llvmVars,
6227 return llvm::make_error<PreviouslyReportedError>();
6229 return llvm::Error::success();
6232 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6234 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6235 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6236 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
6241 builder.restoreIP(*afterIP);
6243 if (doDistributeReduction) {
6246 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6247 privateReductionVariables, isByRef,
6259 if (!cast<mlir::ModuleOp>(op))
6264 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp-device",
6265 attribute.getOpenmpDeviceVersion());
6267 if (attribute.getNoGpuLib())
6270 ompBuilder->createGlobalFlag(
6271 attribute.getDebugKind() ,
6272 "__omp_rtl_debug_kind");
6273 ompBuilder->createGlobalFlag(
6275 .getAssumeTeamsOversubscription()
6277 "__omp_rtl_assume_teams_oversubscription");
6278 ompBuilder->createGlobalFlag(
6280 .getAssumeThreadsOversubscription()
6282 "__omp_rtl_assume_threads_oversubscription");
6283 ompBuilder->createGlobalFlag(
6284 attribute.getAssumeNoThreadState() ,
6285 "__omp_rtl_assume_no_thread_state");
6286 ompBuilder->createGlobalFlag(
6288 .getAssumeNoNestedParallelism()
6290 "__omp_rtl_assume_no_nested_parallelism");
6295 omp::TargetOp targetOp,
6296 llvm::StringRef parentName =
"") {
6297 auto fileLoc = targetOp.getLoc()->findInstanceOf<
FileLineColLoc>();
6299 assert(fileLoc &&
"No file found from location");
6300 StringRef fileName = fileLoc.getFilename().getValue();
6302 llvm::sys::fs::UniqueID id;
6303 uint64_t line = fileLoc.getLine();
6304 if (
auto ec = llvm::sys::fs::getUniqueID(fileName,
id)) {
6305 size_t fileHash = llvm::hash_value(fileName.str());
6306 size_t deviceId = 0xdeadf17e;
6308 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
6310 targetInfo = llvm::TargetRegionEntryInfo(parentName,
id.
getDevice(),
6311 id.getFile(), line);
6318 llvm::IRBuilderBase &builder, llvm::Function *
func) {
6320 "function only supported for target device codegen");
6321 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6322 for (
size_t i = 0; i < mapData.MapClause.size(); ++i) {
6335 if (mapData.IsDeclareTarget[i]) {
6342 if (
auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6343 convertUsersOfConstantsToInstructions(constant,
func,
false);
6350 for (llvm::User *user : mapData.OriginalValue[i]->users())
6351 userVec.push_back(user);
6353 for (llvm::User *user : userVec) {
6354 if (
auto *insn = dyn_cast<llvm::Instruction>(user)) {
6355 if (insn->getFunction() ==
func) {
6356 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6357 llvm::Value *substitute = mapData.BasePointers[i];
6359 : mapOp.getVarPtr())) {
6360 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6361 substitute = builder.CreateLoad(
6362 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6363 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6365 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6412static llvm::IRBuilderBase::InsertPoint
6414 llvm::Value *input, llvm::Value *&retVal,
6415 llvm::IRBuilderBase &builder,
6416 llvm::OpenMPIRBuilder &ompBuilder,
6418 llvm::IRBuilderBase::InsertPoint allocaIP,
6419 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6420 assert(ompBuilder.Config.isTargetDevice() &&
6421 "function only supported for target device codegen");
6422 builder.restoreIP(allocaIP);
6424 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6426 ompBuilder.M.getContext());
6427 unsigned alignmentValue = 0;
6429 for (
size_t i = 0; i < mapData.MapClause.size(); ++i)
6430 if (mapData.OriginalValue[i] == input) {
6431 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6432 capture = mapOp.getMapCaptureType();
6435 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6439 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6440 unsigned int defaultAS =
6441 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6444 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6446 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6447 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6449 builder.CreateStore(&arg, v);
6451 builder.restoreIP(codeGenIP);
6454 case omp::VariableCaptureKind::ByCopy: {
6458 case omp::VariableCaptureKind::ByRef: {
6459 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6461 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6476 if (v->getType()->isPointerTy() && alignmentValue) {
6477 llvm::MDBuilder MDB(builder.getContext());
6478 loadInst->setMetadata(
6479 llvm::LLVMContext::MD_align,
6480 llvm::MDNode::get(builder.getContext(),
6481 MDB.createConstant(llvm::ConstantInt::get(
6482 llvm::Type::getInt64Ty(builder.getContext()),
6489 case omp::VariableCaptureKind::This:
6490 case omp::VariableCaptureKind::VLAType:
6493 assert(
false &&
"Currently unsupported capture kind");
6497 return builder.saveIP();
6514 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6515 for (
auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6516 blockArgIface.getHostEvalBlockArgs())) {
6517 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6521 .Case([&](omp::TeamsOp teamsOp) {
6522 if (teamsOp.getNumTeamsLower() == blockArg)
6523 numTeamsLower = hostEvalVar;
6524 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6526 numTeamsUpper = hostEvalVar;
6527 else if (!teamsOp.getThreadLimitVars().empty() &&
6528 teamsOp.getThreadLimit(0) == blockArg)
6529 threadLimit = hostEvalVar;
6531 llvm_unreachable(
"unsupported host_eval use");
6533 .Case([&](omp::ParallelOp parallelOp) {
6534 if (!parallelOp.getNumThreadsVars().empty() &&
6535 parallelOp.getNumThreads(0) == blockArg)
6536 numThreads = hostEvalVar;
6538 llvm_unreachable(
"unsupported host_eval use");
6540 .Case([&](omp::LoopNestOp loopOp) {
6541 auto processBounds =
6545 for (
auto [i, lb] : llvm::enumerate(opBounds)) {
6546 if (lb == blockArg) {
6549 (*outBounds)[i] = hostEvalVar;
6555 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6556 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6558 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6560 assert(found &&
"unsupported host_eval use");
6562 .DefaultUnreachable(
"unsupported host_eval use");
6574template <
typename OpTy>
6579 if (OpTy casted = dyn_cast<OpTy>(op))
6582 if (immediateParent)
6583 return dyn_cast_if_present<OpTy>(op->
getParentOp());
6592 return std::nullopt;
6595 if (
auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6596 return constAttr.getInt();
6598 return std::nullopt;
6603 uint64_t sizeInBytes = sizeInBits / 8;
6607template <
typename OpTy>
6609 if (op.getNumReductionVars() > 0) {
6614 members.reserve(reductions.size());
6615 for (omp::DeclareReductionOp &red : reductions)
6616 members.push_back(red.getType());
6618 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6634 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6635 bool isTargetDevice,
bool isGPU) {
6638 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6639 if (!isTargetDevice) {
6647 numTeamsLower = teamsOp.getNumTeamsLower();
6649 if (!teamsOp.getNumTeamsUpperVars().empty())
6650 numTeamsUpper = teamsOp.getNumTeams(0);
6651 if (!teamsOp.getThreadLimitVars().empty())
6652 threadLimit = teamsOp.getThreadLimit(0);
6656 if (!parallelOp.getNumThreadsVars().empty())
6657 numThreads = parallelOp.getNumThreads(0);
6663 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6667 if (numTeamsUpper) {
6669 minTeamsVal = maxTeamsVal = *val;
6671 minTeamsVal = maxTeamsVal = 0;
6677 minTeamsVal = maxTeamsVal = 1;
6679 minTeamsVal = maxTeamsVal = -1;
6684 auto setMaxValueFromClause = [](
Value clauseValue, int32_t &
result) {
6698 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6699 if (!targetOp.getThreadLimitVars().empty())
6700 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6701 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6704 int32_t maxThreadsVal = -1;
6706 setMaxValueFromClause(numThreads, maxThreadsVal);
6714 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6715 if (combinedMaxThreadsVal < 0 ||
6716 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6717 combinedMaxThreadsVal = teamsThreadLimitVal;
6719 if (combinedMaxThreadsVal < 0 ||
6720 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6721 combinedMaxThreadsVal = maxThreadsVal;
6723 int32_t reductionDataSize = 0;
6724 if (isGPU && capturedOp) {
6730 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6732 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6733 omp::TargetRegionFlags::spmd) &&
6734 "invalid kernel flags");
6736 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6737 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6738 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6739 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6740 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6741 if (omp::bitEnumContainsAll(kernelFlags,
6742 omp::TargetRegionFlags::spmd |
6743 omp::TargetRegionFlags::no_loop) &&
6744 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6745 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6747 attrs.MinTeams = minTeamsVal;
6748 attrs.MaxTeams.front() = maxTeamsVal;
6749 attrs.MinThreads = 1;
6750 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6751 attrs.ReductionDataSize = reductionDataSize;
6754 if (attrs.ReductionDataSize != 0)
6755 attrs.ReductionBufferLength = 1024;
6767 omp::TargetOp targetOp,
Operation *capturedOp,
6768 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6770 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6772 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6776 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6779 if (!targetOp.getThreadLimitVars().empty()) {
6780 Value targetThreadLimit = targetOp.getThreadLimit(0);
6781 attrs.TargetThreadLimit.front() =
6789 attrs.MinTeams = builder.CreateSExtOrTrunc(
6790 moduleTranslation.
lookupValue(numTeamsLower), builder.getInt32Ty());
6793 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
6794 moduleTranslation.
lookupValue(numTeamsUpper), builder.getInt32Ty());
6796 if (teamsThreadLimit)
6797 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
6798 moduleTranslation.
lookupValue(teamsThreadLimit), builder.getInt32Ty());
6801 attrs.MaxThreads = moduleTranslation.
lookupValue(numThreads);
6803 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6804 omp::TargetRegionFlags::trip_count)) {
6806 attrs.LoopTripCount =
nullptr;
6811 for (
auto [loopLower, loopUpper, loopStep] :
6812 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6813 llvm::Value *lowerBound = moduleTranslation.
lookupValue(loopLower);
6814 llvm::Value *upperBound = moduleTranslation.
lookupValue(loopUpper);
6815 llvm::Value *step = moduleTranslation.
lookupValue(loopStep);
6817 if (!lowerBound || !upperBound || !step) {
6818 attrs.LoopTripCount =
nullptr;
6822 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6823 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6824 loc, lowerBound, upperBound, step,
true,
6825 loopOp.getLoopInclusive());
6827 if (!attrs.LoopTripCount) {
6828 attrs.LoopTripCount = tripCount;
6833 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6838 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6840 attrs.DeviceID = moduleTranslation.
lookupValue(devId);
6842 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6849 auto targetOp = cast<omp::TargetOp>(opInst);
6854 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6863 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6864 assert(parentBB &&
"No insert block is set for the builder");
6865 llvm::Function *parentLLVMFn = parentBB->getParent();
6866 assert(parentLLVMFn &&
"Parent Function must be valid");
6867 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6868 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6869 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6870 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6873 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6874 bool isGPU = ompBuilder->Config.isGPU();
6877 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6878 auto &targetRegion = targetOp.getRegion();
6895 llvm::Function *llvmOutlinedFn =
nullptr;
6896 TargetDirectiveEnumTy targetDirective =
6897 getTargetDirectiveEnumTyFromOp(&opInst);
6901 bool isOffloadEntry =
6902 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6909 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6911 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6912 std::optional<DenseI64ArrayAttr> privateMapIndices =
6913 targetOp.getPrivateMapsAttr();
6915 for (
auto [privVarIdx, privVarSymPair] :
6916 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6917 auto privVar = std::get<0>(privVarSymPair);
6918 auto privSym = std::get<1>(privVarSymPair);
6920 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6921 omp::PrivateClauseOp privatizer =
6924 if (!privatizer.needsMap())
6928 targetOp.getMappedValueForPrivateVar(privVarIdx);
6929 assert(mappedValue &&
"Expected to find mapped value for a privatized "
6930 "variable that needs mapping");
6935 auto mapInfoOp = mappedValue.
getDefiningOp<omp::MapInfoOp>();
6936 [[maybe_unused]]
Type varType = mapInfoOp.getVarType();
6940 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6942 varType == privVar.getType() &&
6943 "Type of private var doesn't match the type of the mapped value");
6947 mappedPrivateVars.insert(
6949 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6950 (*privateMapIndices)[privVarIdx])});
6954 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6955 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6956 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6957 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6958 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6961 llvm::Function *llvmParentFn =
6963 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6964 assert(llvmParentFn && llvmOutlinedFn &&
6965 "Both parent and outlined functions must exist at this point");
6967 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6968 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6970 if (
auto attr = llvmParentFn->getFnAttribute(
"target-cpu");
6971 attr.isStringAttribute())
6972 llvmOutlinedFn->addFnAttr(attr);
6974 if (
auto attr = llvmParentFn->getFnAttribute(
"target-features");
6975 attr.isStringAttribute())
6976 llvmOutlinedFn->addFnAttr(attr);
6978 for (
auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6979 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6980 llvm::Value *mapOpValue =
6981 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6982 moduleTranslation.
mapValue(arg, mapOpValue);
6984 for (
auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6985 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6986 llvm::Value *mapOpValue =
6987 moduleTranslation.
lookupValue(mapInfoOp.getVarPtr());
6988 moduleTranslation.
mapValue(arg, mapOpValue);
6997 allocaIP, &mappedPrivateVars);
7000 return llvm::make_error<PreviouslyReportedError>();
7002 builder.restoreIP(codeGenIP);
7004 &mappedPrivateVars),
7007 return llvm::make_error<PreviouslyReportedError>();
7010 targetOp, builder, moduleTranslation, privateVarsInfo.
mlirVars,
7012 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7013 return llvm::make_error<PreviouslyReportedError>();
7017 std::back_inserter(privateCleanupRegions),
7018 [](omp::PrivateClauseOp privatizer) {
7019 return &privatizer.getDeallocRegion();
7023 targetRegion,
"omp.target", builder, moduleTranslation);
7026 return exitBlock.takeError();
7028 builder.SetInsertPoint(*exitBlock);
7029 if (!privateCleanupRegions.empty()) {
7031 privateCleanupRegions, privateVarsInfo.
llvmVars,
7032 moduleTranslation, builder,
"omp.targetop.private.cleanup",
7034 return llvm::createStringError(
7035 "failed to inline `dealloc` region of `omp.private` "
7036 "op in the target region");
7038 return builder.saveIP();
7041 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
7044 StringRef parentName = parentFn.getName();
7046 llvm::TargetRegionEntryInfo entryInfo;
7050 MapInfoData mapData;
7055 MapInfosTy combinedInfos;
7057 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7058 builder.restoreIP(codeGenIP);
7059 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7064 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7065 combinedInfos.BasePointers.push_back(nullPtr);
7066 combinedInfos.Pointers.push_back(nullPtr);
7067 combinedInfos.DevicePointers.push_back(
7068 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7069 combinedInfos.Sizes.push_back(builder.getInt64(0));
7070 combinedInfos.Types.push_back(
7071 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7072 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7073 if (!combinedInfos.Names.empty())
7074 combinedInfos.Names.push_back(nullPtr);
7075 combinedInfos.Mappers.push_back(
nullptr);
7077 return combinedInfos;
7080 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7081 llvm::Value *&retVal, InsertPointTy allocaIP,
7082 InsertPointTy codeGenIP)
7083 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7084 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7085 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7091 if (!isTargetDevice) {
7092 retVal = cast<llvm::Value>(&arg);
7097 *ompBuilder, moduleTranslation,
7098 allocaIP, codeGenIP);
7101 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7102 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7103 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7105 isTargetDevice, isGPU);
7109 if (!isTargetDevice)
7111 targetCapturedOp, runtimeAttrs);
7119 for (
auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7120 llvm::Value *value = moduleTranslation.
lookupValue(var);
7121 moduleTranslation.
mapValue(arg, value);
7123 if (!llvm::isa<llvm::Constant>(value))
7124 kernelInput.push_back(value);
7127 for (
size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7134 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
7135 kernelInput.push_back(mapData.OriginalValue[i]);
7140 moduleTranslation, dds);
7142 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7144 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7146 llvm::OpenMPIRBuilder::TargetDataInfo info(
7150 auto customMapperCB =
7152 if (!combinedInfos.Mappers[i])
7154 info.HasMapper =
true;
7156 moduleTranslation, targetDirective);
7159 llvm::Value *ifCond =
nullptr;
7160 if (
Value targetIfCond = targetOp.getIfExpr())
7161 ifCond = moduleTranslation.
lookupValue(targetIfCond);
7163 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7165 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
7166 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
7167 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
7172 builder.restoreIP(*afterIP);
7193 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7194 if (
auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7196 if (!offloadMod.getIsTargetDevice())
7199 omp::DeclareTargetDeviceType declareType =
7200 attribute.getDeviceType().getValue();
7202 if (declareType == omp::DeclareTargetDeviceType::host) {
7203 llvm::Function *llvmFunc =
7205 llvmFunc->dropAllReferences();
7206 llvmFunc->eraseFromParent();
7212 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7213 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7214 if (
auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7216 bool isDeclaration = gOp.isDeclaration();
7217 bool isExternallyVisible =
7220 llvm::StringRef mangledName = gOp.getSymName();
7221 auto captureClause =
7227 std::vector<llvm::GlobalVariable *> generatedRefs;
7229 std::vector<llvm::Triple> targetTriple;
7230 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7232 LLVM::LLVMDialect::getTargetTripleAttrName()));
7233 if (targetTripleAttr)
7234 targetTriple.emplace_back(targetTripleAttr.data());
7236 auto fileInfoCallBack = [&loc]() {
7237 std::string filename =
"";
7238 std::uint64_t lineNo = 0;
7241 filename = loc.getFilename().str();
7242 lineNo = loc.getLine();
7245 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7249 auto vfs = llvm::vfs::getRealFileSystem();
7251 ompBuilder->registerTargetGlobalVariable(
7252 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7253 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7254 mangledName, generatedRefs,
false, targetTriple,
7256 gVal->getType(), gVal);
7258 if (ompBuilder->Config.isTargetDevice() &&
7259 (attribute.getCaptureClause().getValue() !=
7260 mlir::omp::DeclareTargetCaptureClause::to ||
7261 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7262 ompBuilder->getAddrOfDeclareTargetVar(
7263 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7264 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7265 mangledName, generatedRefs,
false, targetTriple,
7266 gVal->getType(),
nullptr,
7279class OpenMPDialectLLVMIRTranslationInterface
7280 :
public LLVMTranslationDialectInterface {
7282 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
7287 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7288 LLVM::ModuleTranslation &moduleTranslation)
const final;
7293 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7294 NamedAttribute attribute,
7295 LLVM::ModuleTranslation &moduleTranslation)
const final;
7300LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7301 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7302 NamedAttribute attribute,
7303 LLVM::ModuleTranslation &moduleTranslation)
const {
7304 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7306 .Case(
"omp.is_target_device",
7307 [&](Attribute attr) {
7308 if (
auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7309 llvm::OpenMPIRBuilderConfig &config =
7311 config.setIsTargetDevice(deviceAttr.getValue());
7317 [&](Attribute attr) {
7318 if (
auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7319 llvm::OpenMPIRBuilderConfig &config =
7321 config.setIsGPU(gpuAttr.getValue());
7326 .Case(
"omp.host_ir_filepath",
7327 [&](Attribute attr) {
7328 if (
auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7329 llvm::OpenMPIRBuilder *ompBuilder =
7331 auto VFS = llvm::vfs::getRealFileSystem();
7332 ompBuilder->loadOffloadInfoMetadata(*VFS,
7333 filepathAttr.getValue());
7339 [&](Attribute attr) {
7340 if (
auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7344 .Case(
"omp.version",
7345 [&](Attribute attr) {
7346 if (
auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7347 llvm::OpenMPIRBuilder *ompBuilder =
7349 ompBuilder->M.addModuleFlag(llvm::Module::Max,
"openmp",
7350 versionAttr.getVersion());
7355 .Case(
"omp.declare_target",
7356 [&](Attribute attr) {
7357 if (
auto declareTargetAttr =
7358 dyn_cast<omp::DeclareTargetAttr>(attr))
7363 .Case(
"omp.requires",
7364 [&](Attribute attr) {
7365 if (
auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7366 using Requires = omp::ClauseRequires;
7367 Requires flags = requiresAttr.getValue();
7368 llvm::OpenMPIRBuilderConfig &config =
7370 config.setHasRequiresReverseOffload(
7371 bitEnumContainsAll(flags, Requires::reverse_offload));
7372 config.setHasRequiresUnifiedAddress(
7373 bitEnumContainsAll(flags, Requires::unified_address));
7374 config.setHasRequiresUnifiedSharedMemory(
7375 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7376 config.setHasRequiresDynamicAllocators(
7377 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7382 .Case(
"omp.target_triples",
7383 [&](Attribute attr) {
7384 if (
auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7385 llvm::OpenMPIRBuilderConfig &config =
7387 config.TargetTriples.clear();
7388 config.TargetTriples.reserve(triplesAttr.size());
7389 for (Attribute tripleAttr : triplesAttr) {
7390 if (
auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7391 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7399 .Default([](Attribute) {
7415 if (
auto declareTargetIface =
7416 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7417 parentFn.getOperation()))
7418 if (declareTargetIface.isDeclareTarget() &&
7419 declareTargetIface.getDeclareTargetDeviceType() !=
7420 mlir::omp::DeclareTargetDeviceType::host)
7430 llvm::Module *llvmModule) {
7431 llvm::Type *i64Ty = builder.getInt64Ty();
7432 llvm::Type *i32Ty = builder.getInt32Ty();
7433 llvm::Type *returnType = builder.getPtrTy(0);
7434 llvm::FunctionType *fnType =
7435 llvm::FunctionType::get(returnType, {i64Ty, i32Ty},
false);
7436 llvm::Function *
func = cast<llvm::Function>(
7437 llvmModule->getOrInsertFunction(
"omp_target_alloc", fnType).getCallee());
7444 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7449 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7453 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7455 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7456 mlir::Type heapTy = allocMemOp.getAllocatedType();
7457 llvm::Type *llvmHeapTy = moduleTranslation.
convertType(heapTy);
7458 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7459 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7460 for (
auto typeParam : allocMemOp.getTypeparams())
7462 builder.CreateMul(allocSize, moduleTranslation.
lookupValue(typeParam));
7464 llvm::CallInst *call =
7465 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7466 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7469 moduleTranslation.
mapValue(allocMemOp.getResult(), resultI64);
7474 llvm::Module *llvmModule) {
7475 llvm::Type *ptrTy = builder.getPtrTy(0);
7476 llvm::Type *i32Ty = builder.getInt32Ty();
7477 llvm::Type *voidTy = builder.getVoidTy();
7478 llvm::FunctionType *fnType =
7479 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty},
false);
7480 llvm::Function *
func = dyn_cast<llvm::Function>(
7481 llvmModule->getOrInsertFunction(
"omp_target_free", fnType).getCallee());
7488 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7493 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
7497 llvm::Value *llvmDeviceNum = moduleTranslation.
lookupValue(deviceNum);
7500 llvm::Value *llvmHeapref = moduleTranslation.
lookupValue(heapref);
7502 llvm::Value *intToPtr =
7503 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7504 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7510LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7511 Operation *op, llvm::IRBuilderBase &builder,
7512 LLVM::ModuleTranslation &moduleTranslation)
const {
7515 if (ompBuilder->Config.isTargetDevice() &&
7516 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7519 return op->
emitOpError() <<
"unsupported host op found in device";
7527 bool isOutermostLoopWrapper =
7528 isa_and_present<omp::LoopWrapperInterface>(op) &&
7529 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->
getParentOp());
7531 if (isOutermostLoopWrapper)
7532 moduleTranslation.
stackPush<OpenMPLoopInfoStackFrame>();
7535 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7536 .Case([&](omp::BarrierOp op) -> LogicalResult {
7540 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7541 ompBuilder->createBarrier(builder.saveIP(),
7542 llvm::omp::OMPD_barrier);
7544 if (res.succeeded()) {
7547 builder.restoreIP(*afterIP);
7551 .Case([&](omp::TaskyieldOp op) {
7555 ompBuilder->createTaskyield(builder.saveIP());
7558 .Case([&](omp::FlushOp op) {
7570 ompBuilder->createFlush(builder.saveIP());
7573 .Case([&](omp::ParallelOp op) {
7576 .Case([&](omp::MaskedOp) {
7579 .Case([&](omp::MasterOp) {
7582 .Case([&](omp::CriticalOp) {
7585 .Case([&](omp::OrderedRegionOp) {
7588 .Case([&](omp::OrderedOp) {
7591 .Case([&](omp::WsloopOp) {
7594 .Case([&](omp::SimdOp) {
7597 .Case([&](omp::AtomicReadOp) {
7600 .Case([&](omp::AtomicWriteOp) {
7603 .Case([&](omp::AtomicUpdateOp op) {
7606 .Case([&](omp::AtomicCaptureOp op) {
7609 .Case([&](omp::CancelOp op) {
7612 .Case([&](omp::CancellationPointOp op) {
7615 .Case([&](omp::SectionsOp) {
7618 .Case([&](omp::SingleOp op) {
7621 .Case([&](omp::TeamsOp op) {
7624 .Case([&](omp::TaskOp op) {
7627 .Case([&](omp::TaskloopOp op) {
7630 .Case([&](omp::TaskgroupOp op) {
7633 .Case([&](omp::TaskwaitOp op) {
7636 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7637 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7638 omp::CriticalDeclareOp>([](
auto op) {
7651 .Case([&](omp::ThreadprivateOp) {
7654 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7655 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](
auto op) {
7658 .Case([&](omp::TargetOp) {
7661 .Case([&](omp::DistributeOp) {
7664 .Case([&](omp::LoopNestOp) {
7667 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
7668 omp::AffinityEntryOp, omp::IteratorOp>([&](
auto op) {
7674 .Case([&](omp::NewCliOp op) {
7679 .Case([&](omp::CanonicalLoopOp op) {
7682 .Case([&](omp::UnrollHeuristicOp op) {
7691 .Case([&](omp::TileOp op) {
7692 return applyTile(op, builder, moduleTranslation);
7694 .Case([&](omp::FuseOp op) {
7695 return applyFuse(op, builder, moduleTranslation);
7697 .Case([&](omp::TargetAllocMemOp) {
7700 .Case([&](omp::TargetFreeMemOp) {
7703 .Default([&](Operation *inst) {
7705 <<
"not yet implemented: " << inst->
getName();
7708 if (isOutermostLoopWrapper)
7715 registry.
insert<omp::OpenMPDialect>();
7717 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the corresponding clauses.
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invocation.
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-iteration.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values for a list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region ®ion)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
unsigned getNumArguments()
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
Concrete CRTP base class for StateStack frames.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region ®ion, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location information.
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region ®ion)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry ®istry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and interchanging.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars