MLIR 22.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/StateStack.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
70/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
71/// insertion points for allocas.
72class OpenMPAllocaStackFrame
73 : public StateStackFrameBase<OpenMPAllocaStackFrame> {
74public:
76
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
80};
81
82/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
83/// collapsed canonical loop information corresponding to an \c omp.loop_nest
84/// operation.
85class OpenMPLoopInfoStackFrame
86 : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
87public:
88 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
89 llvm::CanonicalLoopInfo *loopInfo = nullptr;
90};
91
92/// Custom error class to signal translation errors that don't need reporting,
93/// since encountering them will have already triggered relevant error messages.
94///
95/// Its purpose is to serve as the glue between MLIR failures represented as
96/// \see LogicalResult instances and \see llvm::Error instances used to
97/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
98/// error of the first type is raised, a message is emitted directly (the \see
99/// LogicalResult itself does not hold any information). If we need to forward
100/// this error condition as an \see llvm::Error while avoiding triggering some
101/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
102/// class to just signal this situation has happened.
103///
104/// For example, this class should be used to trigger errors from within
105/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
106/// translation of their own regions. This unclutters the error log from
107/// redundant messages.
108class PreviouslyReportedError
109 : public llvm::ErrorInfo<PreviouslyReportedError> {
110public:
111 void log(raw_ostream &) const override {
112 // Do not log anything.
113 }
114
115 std::error_code convertToErrorCode() const override {
116 llvm_unreachable(
117 "PreviouslyReportedError doesn't support ECError conversion");
118 }
119
120 // Used by ErrorInfo::classID.
121 static char ID;
122};
123
124char PreviouslyReportedError::ID = 0;
125
/*
 * Custom class for processing linear clause for omp.wsloop
 * and omp.simd. Linear clause translation requires setup,
 * initialization, update, and finalization at varying
 * basic blocks in the IR. This class helps maintain
 * internal state to allow consistent translation in
 * each of these stages.
 */

class LinearClauseProcessor {

private:
  // Allocas capturing each linear variable's value on loop entry (the start
  // value increments are computed from).
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas holding each linear variable's per-iteration value; loop-body uses
  // of the original variable are redirected here (see rewriteInPlace).
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // LLVM values of the original linear variables.
  SmallVector<llvm::Value *> linearOrigVal;
  // Step value of each linear variable.
  SmallVector<llvm::Value *> linearSteps;
  // Converted LLVM types of the linear variables; parallel to the vectors
  // above and indexed by the same position.
  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks produced by splitLinearFiniBB: the finalization dispatch block,
  // the common exit block, and the block executed only on the last iteration.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register type for the linear variables
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space for linear variables. Allocas are emitted at the builder's
  // current insertion point; callers are expected to have positioned it at a
  // valid alloca insertion point.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       mlir::Value &linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Initialize linear step
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: copy each original
  // variable's current value into its precondition alloca, just before the
  // loop is entered.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR for updating Linear variables: in each iteration compute
  // start + iv * step into the loop-body temporary.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      // Emit increments for linear vars
      llvm::LoadInst *linearVarStart = builder.CreateLoad(
          linearVarTypes[index], linearPreconditionVars[index]);
      // NOTE(review): iv * step is computed as an integer multiply; assumes
      // the induction variable and step are integers — for floating-point
      // linear variables the product is then converted via sitofp. Confirm
      // the step is integral for FP vars; a variable whose type is neither
      // integer nor FP is silently skipped here.
      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
      if (linearVarTypes[index]->isIntegerTy()) {
        auto addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarTypes[index]->isFloatingPointTy()) {
        auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]);
        auto addInst = builder.CreateFAdd(linearVarStart, cvt);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same. After the two splits:
  //   loopExit -> linearFinalizationBB -> linearLastIterExitBB -> linearExitBB
  // and finalizeLinearVar later rewires linearFinalizationBB's terminator
  // into a conditional branch.
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: on the last logical iteration only, store the
  // loop-body temporaries back to the original variables, then emit a barrier
  // in the common exit block. Returns the insertion point after the barrier
  // (or an error from barrier creation).
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Rewrite all uses of the original variable in `BBName`
  // with the linear variable in-place. Matching is by substring on the
  // basic-block name, so all blocks whose name contains `BBName` are affected.
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    // Snapshot the users first: replaceUsesOfWith mutates the use list.
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
267
268} // namespace
269
270/// Looks up from the operation from and returns the PrivateClauseOp with
271/// name symbolName
272static omp::PrivateClauseOp findPrivatizer(Operation *from,
273 SymbolRefAttr symbolName) {
274 omp::PrivateClauseOp privatizer =
276 symbolName);
277 assert(privatizer && "privatizer not found in the symbol table");
278 return privatizer;
279}
280
281/// Check whether translation to LLVM IR for the given operation is currently
282/// supported. If not, descriptive diagnostics will be emitted to let users know
283/// this is a not-yet-implemented feature.
284///
285/// \returns success if no unimplemented features are needed to translate the
286/// given operation.
287static LogicalResult checkImplementationStatus(Operation &op) {
288 auto todo = [&op](StringRef clauseName) {
289 return op.emitError() << "not yet implemented: Unhandled clause "
290 << clauseName << " in " << op.getName()
291 << " operation";
292 };
293
294 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
295 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
296 result = todo("allocate");
297 };
298 auto checkBare = [&todo](auto op, LogicalResult &result) {
299 if (op.getBare())
300 result = todo("ompx_bare");
301 };
302 auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
303 omp::ClauseCancellationConstructType cancelledDirective =
304 op.getCancelDirective();
305 // Cancelling a taskloop is not yet supported because we don't yet have LLVM
306 // IR conversion for taskloop
307 if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
308 Operation *parent = op->getParentOp();
309 while (parent) {
310 if (parent->getDialect() == op->getDialect())
311 break;
312 parent = parent->getParentOp();
313 }
314 if (isa_and_nonnull<omp::TaskloopOp>(parent))
315 result = todo("cancel directive inside of taskloop");
316 }
317 };
318 auto checkDepend = [&todo](auto op, LogicalResult &result) {
319 if (!op.getDependVars().empty() || op.getDependKinds())
320 result = todo("depend");
321 };
322 auto checkDevice = [&todo](auto op, LogicalResult &result) {
323 if (op.getDevice())
324 result = todo("device");
325 };
326 auto checkHint = [](auto op, LogicalResult &) {
327 if (op.getHint())
328 op.emitWarning("hint clause discarded");
329 };
330 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
331 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
332 op.getInReductionSyms())
333 result = todo("in_reduction");
334 };
335 auto checkNowait = [&todo](auto op, LogicalResult &result) {
336 if (op.getNowait())
337 result = todo("nowait");
338 };
339 auto checkOrder = [&todo](auto op, LogicalResult &result) {
340 if (op.getOrder() || op.getOrderMod())
341 result = todo("order");
342 };
343 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
344 if (op.getParLevelSimd())
345 result = todo("parallelization-level");
346 };
347 auto checkPriority = [&todo](auto op, LogicalResult &result) {
348 if (op.getPriority())
349 result = todo("priority");
350 };
351 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
352 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
353 result = todo("privatization");
354 };
355 auto checkReduction = [&todo](auto op, LogicalResult &result) {
356 if (isa<omp::TeamsOp>(op))
357 if (!op.getReductionVars().empty() || op.getReductionByref() ||
358 op.getReductionSyms())
359 result = todo("reduction");
360 if (op.getReductionMod() &&
361 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
362 result = todo("reduction with modifier");
363 };
364 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
365 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
366 op.getTaskReductionSyms())
367 result = todo("task_reduction");
368 };
369 auto checkUntied = [&todo](auto op, LogicalResult &result) {
370 if (op.getUntied())
371 result = todo("untied");
372 };
373
374 LogicalResult result = success();
376 .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
377 .Case([&](omp::CancellationPointOp op) {
378 checkCancelDirective(op, result);
379 })
380 .Case([&](omp::DistributeOp op) {
381 checkAllocate(op, result);
382 checkOrder(op, result);
383 })
384 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
385 .Case([&](omp::SectionsOp op) {
386 checkAllocate(op, result);
387 checkPrivate(op, result);
388 checkReduction(op, result);
389 })
390 .Case([&](omp::SingleOp op) {
391 checkAllocate(op, result);
392 checkPrivate(op, result);
393 })
394 .Case([&](omp::TeamsOp op) {
395 checkAllocate(op, result);
396 checkPrivate(op, result);
397 })
398 .Case([&](omp::TaskOp op) {
399 checkAllocate(op, result);
400 checkInReduction(op, result);
401 })
402 .Case([&](omp::TaskgroupOp op) {
403 checkAllocate(op, result);
404 checkTaskReduction(op, result);
405 })
406 .Case([&](omp::TaskwaitOp op) {
407 checkDepend(op, result);
408 checkNowait(op, result);
409 })
410 .Case([&](omp::TaskloopOp op) {
411 // TODO: Add other clauses check
412 checkUntied(op, result);
413 checkPriority(op, result);
414 })
415 .Case([&](omp::WsloopOp op) {
416 checkAllocate(op, result);
417 checkOrder(op, result);
418 checkReduction(op, result);
419 })
420 .Case([&](omp::ParallelOp op) {
421 checkAllocate(op, result);
422 checkReduction(op, result);
423 })
424 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
425 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
426 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
427 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
428 [&](auto op) { checkDepend(op, result); })
429 .Case([&](omp::TargetOp op) {
430 checkAllocate(op, result);
431 checkBare(op, result);
432 checkDevice(op, result);
433 checkInReduction(op, result);
434 })
435 .Default([](Operation &) {
436 // Assume all clauses for an operation can be translated unless they are
437 // checked above.
438 });
439 return result;
440}
441
442static LogicalResult handleError(llvm::Error error, Operation &op) {
443 LogicalResult result = success();
444 if (error) {
445 llvm::handleAllErrors(
446 std::move(error),
447 [&](const PreviouslyReportedError &) { result = failure(); },
448 [&](const llvm::ErrorInfoBase &err) {
449 result = op.emitError(err.message());
450 });
451 }
452 return result;
453}
454
455template <typename T>
456static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
457 if (!result)
458 return handleError(result.takeError(), op);
459
460 return success();
461}
462
463/// Find the insertion point for allocas given the current insertion point for
464/// normal operations in the builder.
465static llvm::OpenMPIRBuilder::InsertPointTy
466findAllocaInsertPoint(llvm::IRBuilderBase &builder,
467 LLVM::ModuleTranslation &moduleTranslation) {
468 // If there is an alloca insertion point on stack, i.e. we are in a nested
469 // operation and a specific point was provided by some surrounding operation,
470 // use it.
471 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
472 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
473 [&](OpenMPAllocaStackFrame &frame) {
474 allocaInsertPoint = frame.allocaInsertPoint;
475 return WalkResult::interrupt();
476 });
477 // In cases with multiple levels of outlining, the tree walk might find an
478 // alloca insertion point that is inside the original function while the
479 // builder insertion point is inside the outlined function. We need to make
480 // sure that we do not use it in those cases.
481 if (walkResult.wasInterrupted() &&
482 allocaInsertPoint.getBlock()->getParent() ==
483 builder.GetInsertBlock()->getParent())
484 return allocaInsertPoint;
485
486 // Otherwise, insert to the entry block of the surrounding function.
487 // If the current IRBuilder InsertPoint is the function's entry, it cannot
488 // also be used for alloca insertion which would result in insertion order
489 // confusion. Create a new BasicBlock for the Builder and use the entry block
490 // for the allocs.
491 // TODO: Create a dedicated alloca BasicBlock at function creation such that
492 // we do not need to move the current InertPoint here.
493 if (builder.GetInsertBlock() ==
494 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
495 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
496 "Assuming end of basic block");
497 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
498 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
499 builder.GetInsertBlock()->getNextNode());
500 builder.CreateBr(entryBB);
501 builder.SetInsertPoint(entryBB);
502 }
503
504 llvm::BasicBlock &funcEntryBlock =
505 builder.GetInsertBlock()->getParent()->getEntryBlock();
506 return llvm::OpenMPIRBuilder::InsertPointTy(
507 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
508}
509
510/// Find the loop information structure for the loop nest being translated. It
511/// will return a `null` value unless called from the translation function for
512/// a loop wrapper operation after successfully translating its body.
513static llvm::CanonicalLoopInfo *
515 llvm::CanonicalLoopInfo *loopInfo = nullptr;
516 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
517 [&](OpenMPLoopInfoStackFrame &frame) {
518 loopInfo = frame.loopInfo;
519 return WalkResult::interrupt();
520 });
521 return loopInfo;
522}
523
524/// Converts the given region that appears within an OpenMP dialect operation to
525/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
526/// region, and a branch from any block with an successor-less OpenMP terminator
527/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
528/// of the continuation block if provided.
530 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
531 LLVM::ModuleTranslation &moduleTranslation,
532 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
533 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
534
535 llvm::BasicBlock *continuationBlock =
536 splitBB(builder, true, "omp.region.cont");
537 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
538
539 llvm::LLVMContext &llvmContext = builder.getContext();
540 for (Block &bb : region) {
541 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
542 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
543 builder.GetInsertBlock()->getNextNode());
544 moduleTranslation.mapBlock(&bb, llvmBB);
545 }
546
547 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
548
549 // Terminators (namely YieldOp) may be forwarding values to the region that
550 // need to be available in the continuation block. Collect the types of these
551 // operands in preparation of creating PHI nodes. This is skipped for loop
552 // wrapper operations, for which we know in advance they have no terminators.
553 SmallVector<llvm::Type *> continuationBlockPHITypes;
554 unsigned numYields = 0;
555
556 if (!isLoopWrapper) {
557 bool operandsProcessed = false;
558 for (Block &bb : region.getBlocks()) {
559 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
560 if (!operandsProcessed) {
561 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
562 continuationBlockPHITypes.push_back(
563 moduleTranslation.convertType(yield->getOperand(i).getType()));
564 }
565 operandsProcessed = true;
566 } else {
567 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
568 "mismatching number of values yielded from the region");
569 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
570 llvm::Type *operandType =
571 moduleTranslation.convertType(yield->getOperand(i).getType());
572 (void)operandType;
573 assert(continuationBlockPHITypes[i] == operandType &&
574 "values of mismatching types yielded from the region");
575 }
576 }
577 numYields++;
578 }
579 }
580 }
581
582 // Insert PHI nodes in the continuation block for any values forwarded by the
583 // terminators in this region.
584 if (!continuationBlockPHITypes.empty())
585 assert(
586 continuationBlockPHIs &&
587 "expected continuation block PHIs if converted regions yield values");
588 if (continuationBlockPHIs) {
589 llvm::IRBuilderBase::InsertPointGuard guard(builder);
590 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
591 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
592 for (llvm::Type *ty : continuationBlockPHITypes)
593 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
594 }
595
596 // Convert blocks one by one in topological order to ensure
597 // defs are converted before uses.
599 for (Block *bb : blocks) {
600 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
601 // Retarget the branch of the entry block to the entry block of the
602 // converted region (regions are single-entry).
603 if (bb->isEntryBlock()) {
604 assert(sourceTerminator->getNumSuccessors() == 1 &&
605 "provided entry block has multiple successors");
606 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
607 "ContinuationBlock is not the successor of the entry block");
608 sourceTerminator->setSuccessor(0, llvmBB);
609 }
610
611 llvm::IRBuilderBase::InsertPointGuard guard(builder);
612 if (failed(
613 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
614 return llvm::make_error<PreviouslyReportedError>();
615
616 // Create a direct branch here for loop wrappers to prevent their lack of a
617 // terminator from causing a crash below.
618 if (isLoopWrapper) {
619 builder.CreateBr(continuationBlock);
620 continue;
621 }
622
623 // Special handling for `omp.yield` and `omp.terminator` (we may have more
624 // than one): they return the control to the parent OpenMP dialect operation
625 // so replace them with the branch to the continuation block. We handle this
626 // here to avoid relying inter-function communication through the
627 // ModuleTranslation class to set up the correct insertion point. This is
628 // also consistent with MLIR's idiom of handling special region terminators
629 // in the same code that handles the region-owning operation.
630 Operation *terminator = bb->getTerminator();
631 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
632 builder.CreateBr(continuationBlock);
633
634 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
635 (*continuationBlockPHIs)[i]->addIncoming(
636 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
637 }
638 }
639 // After all blocks have been traversed and values mapped, connect the PHI
640 // nodes to the results of preceding blocks.
641 LLVM::detail::connectPHINodes(region, moduleTranslation);
642
643 // Remove the blocks and values defined in this region from the mapping since
644 // they are not visible outside of this region. This allows the same region to
645 // be converted several times, that is cloned, without clashes, and slightly
646 // speeds up the lookups.
647 moduleTranslation.forgetMapping(region);
648
649 return continuationBlock;
650}
651
652/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
653static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
654 switch (kind) {
655 case omp::ClauseProcBindKind::Close:
656 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
657 case omp::ClauseProcBindKind::Master:
658 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
659 case omp::ClauseProcBindKind::Primary:
660 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
661 case omp::ClauseProcBindKind::Spread:
662 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
663 }
664 llvm_unreachable("Unknown ClauseProcBindKind kind");
665}
666
667/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
668/// corresponding LLVM values of \p the interface's operands. This is useful
669/// when an OpenMP region with entry block arguments is converted to LLVM. In
670/// this case the block arguments are (part of) of the OpenMP region's entry
671/// arguments and the operands are (part of) of the operands to the OpenMP op
672/// containing the region.
673static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
674 omp::BlockArgOpenMPOpInterface blockArgIface) {
676 blockArgIface.getBlockArgsPairs(blockArgsPairs);
677 for (auto [var, arg] : blockArgsPairs)
678 moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
679}
680
681/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
682static LogicalResult
683convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
684 LLVM::ModuleTranslation &moduleTranslation) {
685 auto maskedOp = cast<omp::MaskedOp>(opInst);
686 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
687
688 if (failed(checkImplementationStatus(opInst)))
689 return failure();
690
691 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
692 // MaskedOp has only one region associated with it.
693 auto &region = maskedOp.getRegion();
694 builder.restoreIP(codeGenIP);
695 return convertOmpOpRegions(region, "omp.masked.region", builder,
696 moduleTranslation)
697 .takeError();
698 };
699
700 // TODO: Perform finalization actions for variables. This has to be
701 // called for variables which have destructors/finalizers.
702 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
703
704 llvm::Value *filterVal = nullptr;
705 if (auto filterVar = maskedOp.getFilteredThreadId()) {
706 filterVal = moduleTranslation.lookupValue(filterVar);
707 } else {
708 llvm::LLVMContext &llvmContext = builder.getContext();
709 filterVal =
710 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
711 }
712 assert(filterVal != nullptr);
713 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
714 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
715 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
716 finiCB, filterVal);
717
718 if (failed(handleError(afterIP, opInst)))
719 return failure();
720
721 builder.restoreIP(*afterIP);
722 return success();
723}
724
725/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
726static LogicalResult
727convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
728 LLVM::ModuleTranslation &moduleTranslation) {
729 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
730 auto masterOp = cast<omp::MasterOp>(opInst);
731
732 if (failed(checkImplementationStatus(opInst)))
733 return failure();
734
735 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
736 // MasterOp has only one region associated with it.
737 auto &region = masterOp.getRegion();
738 builder.restoreIP(codeGenIP);
739 return convertOmpOpRegions(region, "omp.master.region", builder,
740 moduleTranslation)
741 .takeError();
742 };
743
744 // TODO: Perform finalization actions for variables. This has to be
745 // called for variables which have destructors/finalizers.
746 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
747
748 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
749 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
750 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
751 finiCB);
752
753 if (failed(handleError(afterIP, opInst)))
754 return failure();
755
756 builder.restoreIP(*afterIP);
757 return success();
758}
759
760/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
761static LogicalResult
762convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
763 LLVM::ModuleTranslation &moduleTranslation) {
764 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
765 auto criticalOp = cast<omp::CriticalOp>(opInst);
766
767 if (failed(checkImplementationStatus(opInst)))
768 return failure();
769
770 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
771 // CriticalOp has only one region associated with it.
772 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
773 builder.restoreIP(codeGenIP);
774 return convertOmpOpRegions(region, "omp.critical.region", builder,
775 moduleTranslation)
776 .takeError();
777 };
778
779 // TODO: Perform finalization actions for variables. This has to be
780 // called for variables which have destructors/finalizers.
781 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
782
783 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
784 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
785 llvm::Constant *hint = nullptr;
786
787 // If it has a name, it probably has a hint too.
788 if (criticalOp.getNameAttr()) {
789 // The verifiers in OpenMP Dialect guarentee that all the pointers are
790 // non-null
791 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
792 auto criticalDeclareOp =
794 symbolRef);
795 hint =
796 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
797 static_cast<int>(criticalDeclareOp.getHint()));
798 }
799 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
800 moduleTranslation.getOpenMPBuilder()->createCritical(
801 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
802
803 if (failed(handleError(afterIP, opInst)))
804 return failure();
805
806 builder.restoreIP(*afterIP);
807 return success();
808}
809
810/// A util to collect info needed to convert delayed privatizers from MLIR to
811/// LLVM.
813 template <typename OP>
815 : blockArgs(
816 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
817 mlirVars.reserve(blockArgs.size());
818 llvmVars.reserve(blockArgs.size());
819 collectPrivatizationDecls<OP>(op);
820
821 for (mlir::Value privateVar : op.getPrivateVars())
822 mlirVars.push_back(privateVar);
823 }
824
829
830private:
831 /// Populates `privatizations` with privatization declarations used for the
832 /// given op.
833 template <class OP>
834 void collectPrivatizationDecls(OP op) {
835 std::optional<ArrayAttr> attr = op.getPrivateSyms();
836 if (!attr)
837 return;
838
839 privatizers.reserve(privatizers.size() + attr->size());
840 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
841 privatizers.push_back(findPrivatizer(op, symbolRef));
842 }
843 }
844};
845
846/// Populates `reductions` with reduction declarations used in the given op.
847template <typename T>
848static void
851 std::optional<ArrayAttr> attr = op.getReductionSyms();
852 if (!attr)
853 return;
854
855 reductions.reserve(reductions.size() + op.getNumReductionVars());
856 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
857 reductions.push_back(
859 op, symbolRef));
860 }
861}
862
863/// Translates the blocks contained in the given region and appends them to at
864/// the current insertion point of `builder`. The operations of the entry block
865/// are appended to the current insertion block. If set, `continuationBlockArgs`
866/// is populated with translated values that correspond to the values
867/// omp.yield'ed from the region.
868static LogicalResult inlineConvertOmpRegions(
869 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
870 LLVM::ModuleTranslation &moduleTranslation,
871 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
872 if (region.empty())
873 return success();
874
875 // Special case for single-block regions that don't create additional blocks:
876 // insert operations without creating additional blocks.
877 if (region.hasOneBlock()) {
878 llvm::Instruction *potentialTerminator =
879 builder.GetInsertBlock()->empty() ? nullptr
880 : &builder.GetInsertBlock()->back();
881
882 if (potentialTerminator && potentialTerminator->isTerminator())
883 potentialTerminator->removeFromParent();
884 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
885
886 if (failed(moduleTranslation.convertBlock(
887 region.front(), /*ignoreArguments=*/true, builder)))
888 return failure();
889
890 // The continuation arguments are simply the translated terminator operands.
891 if (continuationBlockArgs)
892 llvm::append_range(
893 *continuationBlockArgs,
894 moduleTranslation.lookupValues(region.front().back().getOperands()));
895
896 // Drop the mapping that is no longer necessary so that the same region can
897 // be processed multiple times.
898 moduleTranslation.forgetMapping(region);
899
900 if (potentialTerminator && potentialTerminator->isTerminator()) {
901 llvm::BasicBlock *block = builder.GetInsertBlock();
902 if (block->empty()) {
903 // this can happen for really simple reduction init regions e.g.
904 // %0 = llvm.mlir.constant(0 : i32) : i32
905 // omp.yield(%0 : i32)
906 // because the llvm.mlir.constant (MLIR op) isn't converted into any
907 // llvm op
908 potentialTerminator->insertInto(block, block->begin());
909 } else {
910 potentialTerminator->insertAfter(&block->back());
911 }
912 }
913
914 return success();
915 }
916
918 llvm::Expected<llvm::BasicBlock *> continuationBlock =
919 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
920
921 if (failed(handleError(continuationBlock, *region.getParentOp())))
922 return failure();
923
924 if (continuationBlockArgs)
925 llvm::append_range(*continuationBlockArgs, phis);
926 builder.SetInsertPoint(*continuationBlock,
927 (*continuationBlock)->getFirstInsertionPt());
928 return success();
929}
930
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: takes the insertion point, the LHS and RHS values, and
/// yields the combined value through the trailing out-reference.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *&)>;
/// Atomic combiner: takes the insertion point, the element type, and the LHS
/// and RHS pointers.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// Generator for the `data_ptr_ptr` region of a by-ref reduction: takes the
/// by-ref value and yields the data pointer through the out-reference.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
946
947/// Create an OpenMPIRBuilder-compatible reduction generator for the given
948/// reduction declaration. The generator uses `builder` but ignores its
949/// insertion point.
950static OwningReductionGen
951makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
952 LLVM::ModuleTranslation &moduleTranslation) {
953 // The lambda is mutable because we need access to non-const methods of decl
954 // (which aren't actually mutating it), and we must capture decl by-value to
955 // avoid the dangling reference after the parent function returns.
956 OwningReductionGen gen =
957 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
958 llvm::Value *lhs, llvm::Value *rhs,
959 llvm::Value *&result) mutable
960 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
961 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
962 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
963 builder.restoreIP(insertPoint);
965 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
966 "omp.reduction.nonatomic.body", builder,
967 moduleTranslation, &phis)))
968 return llvm::createStringError(
969 "failed to inline `combiner` region of `omp.declare_reduction`");
970 result = llvm::getSingleElement(phis);
971 return builder.saveIP();
972 };
973 return gen;
974}
975
976/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
977/// given reduction declaration. The generator uses `builder` but ignores its
978/// insertion point. Returns null if there is no atomic region available in the
979/// reduction declaration.
980static OwningAtomicReductionGen
981makeAtomicReductionGen(omp::DeclareReductionOp decl,
982 llvm::IRBuilderBase &builder,
983 LLVM::ModuleTranslation &moduleTranslation) {
984 if (decl.getAtomicReductionRegion().empty())
985 return OwningAtomicReductionGen();
986
987 // The lambda is mutable because we need access to non-const methods of decl
988 // (which aren't actually mutating it), and we must capture decl by-value to
989 // avoid the dangling reference after the parent function returns.
990 OwningAtomicReductionGen atomicGen =
991 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
992 llvm::Value *lhs, llvm::Value *rhs) mutable
993 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
994 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
995 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
996 builder.restoreIP(insertPoint);
998 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
999 "omp.reduction.atomic.body", builder,
1000 moduleTranslation, &phis)))
1001 return llvm::createStringError(
1002 "failed to inline `atomic` region of `omp.declare_reduction`");
1003 assert(phis.empty());
1004 return builder.saveIP();
1005 };
1006 return atomicGen;
1007}
1008
1009/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1010/// the given reduction declaration. The generator uses `builder` but ignores
1011/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1012/// available in the reduction declaration.
1013static OwningDataPtrPtrReductionGen
1014makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1015 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1016 if (!isByRef)
1017 return OwningDataPtrPtrReductionGen();
1018
1019 OwningDataPtrPtrReductionGen refDataPtrGen =
1020 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1021 llvm::Value *byRefVal, llvm::Value *&result) mutable
1022 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1023 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1024 builder.restoreIP(insertPoint);
1026 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1027 "omp.data_ptr_ptr.body", builder,
1028 moduleTranslation, &phis)))
1029 return llvm::createStringError(
1030 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1031 result = llvm::getSingleElement(phis);
1032 return builder.saveIP();
1033 };
1034
1035 return refDataPtrGen;
1036}
1037
1038/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1039static LogicalResult
1040convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1041 LLVM::ModuleTranslation &moduleTranslation) {
1042 auto orderedOp = cast<omp::OrderedOp>(opInst);
1043
1044 if (failed(checkImplementationStatus(opInst)))
1045 return failure();
1046
1047 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1048 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1049 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1050 SmallVector<llvm::Value *> vecValues =
1051 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1052
1053 size_t indexVecValues = 0;
1054 while (indexVecValues < vecValues.size()) {
1055 SmallVector<llvm::Value *> storeValues;
1056 storeValues.reserve(numLoops);
1057 for (unsigned i = 0; i < numLoops; i++) {
1058 storeValues.push_back(vecValues[indexVecValues]);
1059 indexVecValues++;
1060 }
1061 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1062 findAllocaInsertPoint(builder, moduleTranslation);
1063 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1064 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1065 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1066 }
1067 return success();
1068}
1069
1070/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1071/// OpenMPIRBuilder.
1072static LogicalResult
1073convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1074 LLVM::ModuleTranslation &moduleTranslation) {
1075 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1076 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1077
1078 if (failed(checkImplementationStatus(opInst)))
1079 return failure();
1080
1081 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1082 // OrderedOp has only one region associated with it.
1083 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1084 builder.restoreIP(codeGenIP);
1085 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1086 moduleTranslation)
1087 .takeError();
1088 };
1089
1090 // TODO: Perform finalization actions for variables. This has to be
1091 // called for variables which have destructors/finalizers.
1092 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1093
1094 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1095 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1096 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1097 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1098
1099 if (failed(handleError(afterIP, opInst)))
1100 return failure();
1101
1102 builder.restoreIP(*afterIP);
1103 return success();
1104}
1105
1106namespace {
1107/// Contains the arguments for an LLVM store operation
1108struct DeferredStore {
1109 DeferredStore(llvm::Value *value, llvm::Value *address)
1110 : value(value), address(address) {}
1111
1112 llvm::Value *value;
1113 llvm::Value *address;
1114};
1115} // namespace
1116
1117/// Allocate space for privatized reduction variables.
1118/// `deferredStores` contains information to create store operations which needs
1119/// to be inserted after all allocas
1120template <typename T>
1121static LogicalResult
1123 llvm::IRBuilderBase &builder,
1124 LLVM::ModuleTranslation &moduleTranslation,
1125 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1127 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1128 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1129 SmallVectorImpl<DeferredStore> &deferredStores,
1130 llvm::ArrayRef<bool> isByRefs) {
1131 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1132 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1133
1134 // delay creating stores until after all allocas
1135 deferredStores.reserve(loop.getNumReductionVars());
1136
1137 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1138 Region &allocRegion = reductionDecls[i].getAllocRegion();
1139 if (isByRefs[i]) {
1140 if (allocRegion.empty())
1141 continue;
1142
1144 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1145 builder, moduleTranslation, &phis)))
1146 return loop.emitError(
1147 "failed to inline `alloc` region of `omp.declare_reduction`");
1148
1149 assert(phis.size() == 1 && "expected one allocation to be yielded");
1150 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1151
1152 // Allocate reduction variable (which is a pointer to the real reduction
1153 // variable allocated in the inlined region)
1154 llvm::Value *var = builder.CreateAlloca(
1155 moduleTranslation.convertType(reductionDecls[i].getType()));
1156
1157 llvm::Type *ptrTy = builder.getPtrTy();
1158 llvm::Value *castVar =
1159 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1160 llvm::Value *castPhi =
1161 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1162
1163 deferredStores.emplace_back(castPhi, castVar);
1164
1165 privateReductionVariables[i] = castVar;
1166 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1167 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1168 } else {
1169 assert(allocRegion.empty() &&
1170 "allocaction is implicit for by-val reduction");
1171 llvm::Value *var = builder.CreateAlloca(
1172 moduleTranslation.convertType(reductionDecls[i].getType()));
1173
1174 llvm::Type *ptrTy = builder.getPtrTy();
1175 llvm::Value *castVar =
1176 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1177
1178 moduleTranslation.mapValue(reductionArgs[i], castVar);
1179 privateReductionVariables[i] = castVar;
1180 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1181 }
1182 }
1183
1184 return success();
1185}
1186
1187/// Map input arguments to reduction initialization region
1188template <typename T>
1189static void
1191 llvm::IRBuilderBase &builder,
1193 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1194 unsigned i) {
1195 // map input argument to the initialization region
1196 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1197 Region &initializerRegion = reduction.getInitializerRegion();
1198 Block &entry = initializerRegion.front();
1199
1200 mlir::Value mlirSource = loop.getReductionVars()[i];
1201 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1202 llvm::Value *origVal = llvmSource;
1203 // If a non-pointer value is expected, load the value from the source pointer.
1204 if (!isa<LLVM::LLVMPointerType>(
1205 reduction.getInitializerMoldArg().getType()) &&
1206 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1207 origVal =
1208 builder.CreateLoad(moduleTranslation.convertType(
1209 reduction.getInitializerMoldArg().getType()),
1210 llvmSource, "omp_orig");
1211 }
1212 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1213
1214 if (entry.getNumArguments() > 1) {
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1217 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1218 }
1219}
1220
1221static void
1222setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1223 llvm::BasicBlock *block = nullptr) {
1224 if (block == nullptr)
1225 block = builder.GetInsertBlock();
1226
1227 if (block->empty() || block->getTerminator() == nullptr)
1228 builder.SetInsertPoint(block);
1229 else
1230 builder.SetInsertPoint(block->getTerminator());
1231}
1232
1233/// Inline reductions' `init` regions. This functions assumes that the
1234/// `builder`'s insertion point is where the user wants the `init` regions to be
1235/// inlined; i.e. it does not try to find a proper insertion location for the
1236/// `init` regions. It also leaves the `builder's insertions point in a state
1237/// where the user can continue the code-gen directly afterwards.
1238template <typename OP>
1239static LogicalResult
1240initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1241 llvm::IRBuilderBase &builder,
1242 LLVM::ModuleTranslation &moduleTranslation,
1243 llvm::BasicBlock *latestAllocaBlock,
1245 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1246 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1247 llvm::ArrayRef<bool> isByRef,
1248 SmallVectorImpl<DeferredStore> &deferredStores) {
1249 if (op.getNumReductionVars() == 0)
1250 return success();
1251
1252 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(allocaIP);
1256 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1257
1258 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1259 if (isByRef[i]) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1261 continue;
1262
1263 // TODO: remove after all users of by-ref are updated to use the alloc
1264 // region: Allocate reduction variable (which is a pointer to the real
1265 // reduciton variable allocated in the inlined region)
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.convertType(reductionDecls[i].getType()));
1268 }
1269 }
1270
1271 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1272
1273 // store result of the alloc region to the allocated pointer to the real
1274 // reduction variable
1275 for (auto [data, addr] : deferredStores)
1276 builder.CreateStore(data, addr);
1277
1278 // Before the loop, store the initial values of reductions into reduction
1279 // variables. Although this could be done after allocas, we don't want to mess
1280 // up with the alloca insertion point.
1281 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1283
1284 // map block argument to initializer region
1285 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1286 reductionVariableMap, i);
1287
1288 // TODO In some cases (specially on the GPU), the init regions may
1289 // contains stack alloctaions. If the region is inlined in a loop, this is
1290 // problematic. Instead of just inlining the region, handle allocations by
1291 // hoisting fixed length allocations to the function entry and using
1292 // stacksave and restore for variable length ones.
1293 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1294 "omp.reduction.neutral", builder,
1295 moduleTranslation, &phis)))
1296 return failure();
1297
1298 assert(phis.size() == 1 && "expected one value to be yielded from the "
1299 "reduction neutral element declaration region");
1300
1302
1303 if (isByRef[i]) {
1304 if (!reductionDecls[i].getAllocRegion().empty())
1305 // done in allocReductionVars
1306 continue;
1307
1308 // TODO: this path can be removed once all users of by-ref are updated to
1309 // use an alloc region
1310
1311 // Store the result of the inlined region to the allocated reduction var
1312 // ptr
1313 builder.CreateStore(phis[0], byRefVars[i]);
1314
1315 privateReductionVariables[i] = byRefVars[i];
1316 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1317 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1318 } else {
1319 // for by-ref case the store is inside of the reduction region
1320 builder.CreateStore(phis[0], privateReductionVariables[i]);
1321 // the rest was handled in allocByValReductionVars
1322 }
1323
1324 // forget the mapping for the initializer region because we might need a
1325 // different mapping if this reduction declaration is re-used for a
1326 // different variable
1327 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1328 }
1329
1330 return success();
1331}
1332
/// Collect reduction info
// Builds the generator callbacks and ReductionInfo records consumed by
// OpenMPIRBuilder::createReductions for each reduction on `loop`.
// NOTE(review): several continuation lines of this signature appear to have
// been dropped from this listing (the `reductionDecls`, `owningReductionGens`
// and `reductionInfos` parameters referenced below) — verify against the
// original source file.
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    ArrayRef<bool> isByRef) {
  unsigned numReductions = loop.getNumReductionVars();

  // Create one owning generator of each kind per reduction declaration.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
    // NOTE(review): the opening of the call wrapping the next line (likely a
    // push_back of makeRefDataPtrGen) appears to be missing from this listing.
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    // The atomic generator is optional; it is null when the declaration has
    // no atomic region (see makeAtomicReductionGen).
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    // Recover the element type produced by the declaration's alloc region, if
    // any, by scanning that region for an LLVM alloca.
    mlir::Type allocatedType;
    reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
      if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
        allocatedType = alloca.getElemType();
        // NOTE(review): the walk-result return statements appear to be
        // missing from this listing.
      }

    });

    // NOTE(review): one or two initializer fields appear to be missing from
    // this aggregate (around the ReductionGen/atomicGen entries).
    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         /*ReductionGenClang=*/nullptr, atomicGen,
         allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
         reductionDecls[i].getByrefElementType()
             ? moduleTranslation.convertType(
                   *reductionDecls[i].getByrefElementType())
             : nullptr});
  }
}
1388
1389/// handling of DeclareReductionOp's cleanup region
1390static LogicalResult
1392 llvm::ArrayRef<llvm::Value *> privateVariables,
1393 LLVM::ModuleTranslation &moduleTranslation,
1394 llvm::IRBuilderBase &builder, StringRef regionName,
1395 bool shouldLoadCleanupRegionArg = true) {
1396 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1397 if (cleanupRegion->empty())
1398 continue;
1399
1400 // map the argument to the cleanup region
1401 Block &entry = cleanupRegion->front();
1402
1403 llvm::Instruction *potentialTerminator =
1404 builder.GetInsertBlock()->empty() ? nullptr
1405 : &builder.GetInsertBlock()->back();
1406 if (potentialTerminator && potentialTerminator->isTerminator())
1407 builder.SetInsertPoint(potentialTerminator);
1408 llvm::Value *privateVarValue =
1409 shouldLoadCleanupRegionArg
1410 ? builder.CreateLoad(
1411 moduleTranslation.convertType(entry.getArgument(0).getType()),
1412 privateVariables[i])
1413 : privateVariables[i];
1414
1415 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1416
1417 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1418 moduleTranslation)))
1419 return failure();
1420
1421 // clear block argument mapping in case it needs to be re-created with a
1422 // different source for another use of the same reduction decl
1423 moduleTranslation.forgetMapping(*cleanupRegion);
1424 }
1425 return success();
1426}
1427
1428// TODO: not used by ParallelOp
1429template <class OP>
1430static LogicalResult createReductionsAndCleanup(
1431 OP op, llvm::IRBuilderBase &builder,
1432 LLVM::ModuleTranslation &moduleTranslation,
1433 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1435 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1436 bool isNowait = false, bool isTeamsReduction = false) {
1437 // Process the reductions if required.
1438 if (op.getNumReductionVars() == 0)
1439 return success();
1440
1442 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1443 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1445
1446 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1447
1448 // Create the reduction generators. We need to own them here because
1449 // ReductionInfo only accepts references to the generators.
1450 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1451 owningReductionGens, owningAtomicReductionGens,
1452 owningReductionGenRefDataPtrGens,
1453 privateReductionVariables, reductionInfos, isByRef);
1454
1455 // The call to createReductions below expects the block to have a
1456 // terminator. Create an unreachable instruction to serve as terminator
1457 // and remove it later.
1458 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1459 builder.SetInsertPoint(tempTerminator);
1460 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1461 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1462 isByRef, isNowait, isTeamsReduction);
1463
1464 if (failed(handleError(contInsertPoint, *op)))
1465 return failure();
1466
1467 if (!contInsertPoint->getBlock())
1468 return op->emitOpError() << "failed to convert reductions";
1469
1470 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1471 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1472
1473 if (failed(handleError(afterIP, *op)))
1474 return failure();
1475
1476 tempTerminator->eraseFromParent();
1477 builder.restoreIP(*afterIP);
1478
1479 // after the construct, deallocate private reduction variables
1480 SmallVector<Region *> reductionRegions;
1481 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1482 [](omp::DeclareReductionOp reductionDecl) {
1483 return &reductionDecl.getCleanupRegion();
1484 });
1485 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1486 moduleTranslation, builder,
1487 "omp.reduction.cleanup");
1488 return success();
1489}
1490
1491static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1492 if (!attr)
1493 return {};
1494 return *attr;
1495}
1496
1497// TODO: not used by omp.parallel
1498template <typename OP>
1500 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1501 LLVM::ModuleTranslation &moduleTranslation,
1502 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1504 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1505 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1506 llvm::ArrayRef<bool> isByRef) {
1507 if (op.getNumReductionVars() == 0)
1508 return success();
1509
1510 SmallVector<DeferredStore> deferredStores;
1511
1512 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1513 allocaIP, reductionDecls,
1514 privateReductionVariables, reductionVariableMap,
1515 deferredStores, isByRef)))
1516 return failure();
1517
1518 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1519 allocaIP.getBlock(), reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 isByRef, deferredStores);
1522}
1523
1524/// Return the llvm::Value * corresponding to the `privateVar` that
1525/// is being privatized. It isn't always as simple as looking up
1526/// moduleTranslation with privateVar. For instance, in case of
1527/// an allocatable, the descriptor for the allocatable is privatized.
1528/// This descriptor is mapped using an MapInfoOp. So, this function
1529/// will return a pointer to the llvm::Value corresponding to the
1530/// block argument for the mapped descriptor.
1531static llvm::Value *
1532findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1533 LLVM::ModuleTranslation &moduleTranslation,
1534 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1535 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1536 return moduleTranslation.lookupValue(privateVar);
1537
1538 Value blockArg = (*mappedPrivateVars)[privateVar];
1539 Type privVarType = privateVar.getType();
1540 Type blockArgType = blockArg.getType();
1541 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1542 "A block argument corresponding to a mapped var should have "
1543 "!llvm.ptr type");
1544
1545 if (privVarType == blockArgType)
1546 return moduleTranslation.lookupValue(blockArg);
1547
1548 // This typically happens when the privatized type is lowered from
1549 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1550 // struct/pair is passed by value. But, mapped values are passed only as
1551 // pointers, so before we privatize, we must load the pointer.
1552 if (!isa<LLVM::LLVMPointerType>(privVarType))
1553 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1554 moduleTranslation.lookupValue(blockArg));
1555
1556 return moduleTranslation.lookupValue(privateVar);
1557}
1558
1559/// Initialize a single (first)private variable. You probably want to use
1560/// allocateAndInitPrivateVars instead of this.
1561/// This returns the private variable which has been initialized. This
1562/// variable should be mapped before constructing the body of the Op.
1564 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1565 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1566 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1567 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1568 Region &initRegion = privDecl.getInitRegion();
1569 if (initRegion.empty())
1570 return llvmPrivateVar;
1571
1572 // map initialization region block arguments
1573 llvm::Value *nonPrivateVar = findAssociatedValue(
1574 mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1575 assert(nonPrivateVar);
1576 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1577 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1578
1579 // in-place convert the private initialization region
1581 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1582 moduleTranslation, &phis)))
1583 return llvm::createStringError(
1584 "failed to inline `init` region of `omp.private`");
1585
1586 assert(phis.size() == 1 && "expected one allocation to be yielded");
1587
1588 // clear init region block argument mapping in case it needs to be
1589 // re-created with a different source for another use of the same
1590 // reduction decl
1591 moduleTranslation.forgetMapping(initRegion);
1592
1593 // Prefer the value yielded from the init region to the allocated private
1594 // variable in case the region is operating on arguments by-value (e.g.
1595 // Fortran character boxes).
1596 return phis[0];
1597}
1598
1599static llvm::Error
1600initPrivateVars(llvm::IRBuilderBase &builder,
1601 LLVM::ModuleTranslation &moduleTranslation,
1602 PrivateVarsInfo &privateVarsInfo,
1603 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1604 if (privateVarsInfo.blockArgs.empty())
1605 return llvm::Error::success();
1606
1607 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1608 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1609
1610 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1611 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1612 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1613 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1615 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1616 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1617
1618 if (!privVarOrErr)
1619 return privVarOrErr.takeError();
1620
1621 llvmPrivateVar = privVarOrErr.get();
1622 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1623
1625 }
1626
1627 return llvm::Error::success();
1628}
1629
/// Allocate and initialize delayed private variables. Returns the basic block
/// which comes after all of these allocations. llvm::Value * for each of these
/// private variables are populated in llvmPrivateVars.
///
/// \p allocaIP is where the `alloca`s are emitted. \p mappedPrivateVars
/// optionally remaps the MLIR private variables to other values.
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    PrivateVarsInfo &privateVarsInfo,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  // Split the alloca block so that everything after the allocation point ends
  // up in a separate "omp.region.after_alloca" block.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator since the alloca block was split above.
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  // The new terminator is an unconditional branch created by the splitBB above.
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  // Allocas land in the target's alloca address space; the rest of the region
  // expects pointers in the default (program) address space.
  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS = moduleTranslation.getLLVMModule()
                               ->getDataLayout()
                               .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs)) {
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
    // Cast to the default address space when it differs from the alloca
    // address space.
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  }

  return afterAllocas;
}
1680
1681static LogicalResult copyFirstPrivateVars(
1682 mlir::Operation *op, llvm::IRBuilderBase &builder,
1683 LLVM::ModuleTranslation &moduleTranslation,
1684 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1685 ArrayRef<llvm::Value *> llvmPrivateVars,
1686 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1687 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1688 // Apply copy region for firstprivate.
1689 bool needsFirstprivate =
1690 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1691 return privOp.getDataSharingType() ==
1692 omp::DataSharingClauseType::FirstPrivate;
1693 });
1694
1695 if (!needsFirstprivate)
1696 return success();
1697
1698 llvm::BasicBlock *copyBlock =
1699 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1700 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1701
1702 for (auto [decl, mlirVar, llvmVar] :
1703 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1704 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1705 continue;
1706
1707 // copyRegion implements `lhs = rhs`
1708 Region &copyRegion = decl.getCopyRegion();
1709
1710 // map copyRegion rhs arg
1711 llvm::Value *nonPrivateVar = findAssociatedValue(
1712 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1713 assert(nonPrivateVar);
1714 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1715
1716 // map copyRegion lhs arg
1717 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1718
1719 // in-place convert copy region
1720 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1721 moduleTranslation)))
1722 return decl.emitError("failed to inline `copy` region of `omp.private`");
1723
1725
1726 // ignore unused value yielded from copy region
1727
1728 // clear copy region block argument mapping in case it needs to be
1729 // re-created with different sources for reuse of the same reduction
1730 // decl
1731 moduleTranslation.forgetMapping(copyRegion);
1732 }
1733
1734 if (insertBarrier) {
1735 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1736 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1737 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1738 if (failed(handleError(res, *op)))
1739 return failure();
1740 }
1741
1742 return success();
1743}
1744
1745static LogicalResult
1746cleanupPrivateVars(llvm::IRBuilderBase &builder,
1747 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1748 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1750 // private variable deallocation
1751 SmallVector<Region *> privateCleanupRegions;
1752 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1753 [](omp::PrivateClauseOp privatizer) {
1754 return &privatizer.getDeallocRegion();
1755 });
1756
1757 if (failed(inlineOmpRegionCleanup(
1758 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1759 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1760 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1761 "`omp.private` op in");
1762
1763 return success();
1764}
1765
1766/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1768 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1769 // be visible and not inside of function calls. This is enforced by the
1770 // verifier.
1771 return op
1772 ->walk([](Operation *child) {
1773 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1774 return WalkResult::interrupt();
1775 return WalkResult::advance();
1776 })
1777 .wasInterrupted();
1778}
1779
1780static LogicalResult
1781convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1782 LLVM::ModuleTranslation &moduleTranslation) {
1783 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1784 using StorableBodyGenCallbackTy =
1785 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1786
1787 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1788
1789 if (failed(checkImplementationStatus(opInst)))
1790 return failure();
1791
1792 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1793 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1794
1796 collectReductionDecls(sectionsOp, reductionDecls);
1797 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1798 findAllocaInsertPoint(builder, moduleTranslation);
1799
1800 SmallVector<llvm::Value *> privateReductionVariables(
1801 sectionsOp.getNumReductionVars());
1802 DenseMap<Value, llvm::Value *> reductionVariableMap;
1803
1804 MutableArrayRef<BlockArgument> reductionArgs =
1805 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1806
1808 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1809 reductionDecls, privateReductionVariables, reductionVariableMap,
1810 isByRef)))
1811 return failure();
1812
1814
1815 for (Operation &op : *sectionsOp.getRegion().begin()) {
1816 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1817 if (!sectionOp) // omp.terminator
1818 continue;
1819
1820 Region &region = sectionOp.getRegion();
1821 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1822 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1823 builder.restoreIP(codeGenIP);
1824
1825 // map the omp.section reduction block argument to the omp.sections block
1826 // arguments
1827 // TODO: this assumes that the only block arguments are reduction
1828 // variables
1829 assert(region.getNumArguments() ==
1830 sectionsOp.getRegion().getNumArguments());
1831 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1832 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1833 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1834 assert(llvmVal);
1835 moduleTranslation.mapValue(sectionArg, llvmVal);
1836 }
1837
1838 return convertOmpOpRegions(region, "omp.section.region", builder,
1839 moduleTranslation)
1840 .takeError();
1841 };
1842 sectionCBs.push_back(sectionCB);
1843 }
1844
1845 // No sections within omp.sections operation - skip generation. This situation
1846 // is only possible if there is only a terminator operation inside the
1847 // sections operation
1848 if (sectionCBs.empty())
1849 return success();
1850
1851 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1852
1853 // TODO: Perform appropriate actions according to the data-sharing
1854 // attribute (shared, private, firstprivate, ...) of variables.
1855 // Currently defaults to shared.
1856 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1857 llvm::Value &vPtr, llvm::Value *&replacementValue)
1858 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1859 replacementValue = &vPtr;
1860 return codeGenIP;
1861 };
1862
1863 // TODO: Perform finalization actions for variables. This has to be
1864 // called for variables which have destructors/finalizers.
1865 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1866
1867 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1868 bool isCancellable = constructIsCancellable(sectionsOp);
1869 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1870 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1871 moduleTranslation.getOpenMPBuilder()->createSections(
1872 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1873 sectionsOp.getNowait());
1874
1875 if (failed(handleError(afterIP, opInst)))
1876 return failure();
1877
1878 builder.restoreIP(*afterIP);
1879
1880 // Process the reductions if required.
1882 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1883 privateReductionVariables, isByRef, sectionsOp.getNowait());
1884}
1885
1886/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1887static LogicalResult
1888convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1889 LLVM::ModuleTranslation &moduleTranslation) {
1890 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1892
1893 if (failed(checkImplementationStatus(*singleOp)))
1894 return failure();
1895
1896 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1897 builder.restoreIP(codegenIP);
1898 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1899 builder, moduleTranslation)
1900 .takeError();
1901 };
1902 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1903
1904 // Handle copyprivate
1905 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1906 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1909 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1910 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1912 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1913 llvmCPFuncs.push_back(
1914 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1915 }
1916
1917 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1918 moduleTranslation.getOpenMPBuilder()->createSingle(
1919 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1920 llvmCPFuncs);
1921
1922 if (failed(handleError(afterIP, *singleOp)))
1923 return failure();
1924
1925 builder.restoreIP(*afterIP);
1926 return success();
1927}
1928
1929static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
1930 auto iface =
1931 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1932 // Check that all uses of the reduction block arg has the same distribute op
1933 // parent.
1935 Operation *distOp = nullptr;
1936 for (auto ra : iface.getReductionBlockArgs())
1937 for (auto &use : ra.getUses()) {
1938 auto *useOp = use.getOwner();
1939 // Ignore debug uses.
1940 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1941 debugUses.push_back(useOp);
1942 continue;
1943 }
1944
1945 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1946 // Use is not inside a distribute op - return false
1947 if (!currentDistOp)
1948 return false;
1949 // Multiple distribute operations - return false
1950 Operation *currentOp = currentDistOp.getOperation();
1951 if (distOp && (distOp != currentOp))
1952 return false;
1953
1954 distOp = currentOp;
1955 }
1956
1957 // If we are going to use distribute reduction then remove any debug uses of
1958 // the reduction parameters in teamsOp. Otherwise they will be left without
1959 // any mapped value in moduleTranslation and will eventually error out.
1960 for (auto *use : debugUses)
1961 use->erase();
1962 return true;
1963}
1964
1965// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
1966static LogicalResult
1967convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
1968 LLVM::ModuleTranslation &moduleTranslation) {
1969 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1970 if (failed(checkImplementationStatus(*op)))
1971 return failure();
1972
1973 DenseMap<Value, llvm::Value *> reductionVariableMap;
1974 unsigned numReductionVars = op.getNumReductionVars();
1976 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
1977 llvm::ArrayRef<bool> isByRef;
1978 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1979 findAllocaInsertPoint(builder, moduleTranslation);
1980
1981 // Only do teams reduction if there is no distribute op that captures the
1982 // reduction instead.
1983 bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
1984 if (doTeamsReduction) {
1985 isByRef = getIsByRef(op.getReductionByref());
1986
1987 assert(isByRef.size() == op.getNumReductionVars());
1988
1989 MutableArrayRef<BlockArgument> reductionArgs =
1990 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1991
1992 collectReductionDecls(op, reductionDecls);
1993
1995 op, reductionArgs, builder, moduleTranslation, allocaIP,
1996 reductionDecls, privateReductionVariables, reductionVariableMap,
1997 isByRef)))
1998 return failure();
1999 }
2000
2001 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2003 moduleTranslation, allocaIP);
2004 builder.restoreIP(codegenIP);
2005 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2006 moduleTranslation)
2007 .takeError();
2008 };
2009
2010 llvm::Value *numTeamsLower = nullptr;
2011 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2012 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2013
2014 llvm::Value *numTeamsUpper = nullptr;
2015 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
2016 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
2017
2018 llvm::Value *threadLimit = nullptr;
2019 if (Value threadLimitVar = op.getThreadLimit())
2020 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
2021
2022 llvm::Value *ifExpr = nullptr;
2023 if (Value ifVar = op.getIfExpr())
2024 ifExpr = moduleTranslation.lookupValue(ifVar);
2025
2026 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2027 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2028 moduleTranslation.getOpenMPBuilder()->createTeams(
2029 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2030
2031 if (failed(handleError(afterIP, *op)))
2032 return failure();
2033
2034 builder.restoreIP(*afterIP);
2035 if (doTeamsReduction) {
2036 // Process the reductions if required.
2038 op, builder, moduleTranslation, allocaIP, reductionDecls,
2039 privateReductionVariables, isByRef,
2040 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2041 }
2042 return success();
2043}
2044
2045static void
2046buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2047 LLVM::ModuleTranslation &moduleTranslation,
2049 if (dependVars.empty())
2050 return;
2051 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2052 llvm::omp::RTLDependenceKindTy type;
2053 switch (
2054 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2055 case mlir::omp::ClauseTaskDepend::taskdependin:
2056 type = llvm::omp::RTLDependenceKindTy::DepIn;
2057 break;
2058 // The OpenMP runtime requires that the codegen for 'depend' clause for
2059 // 'out' dependency kind must be the same as codegen for 'depend' clause
2060 // with 'inout' dependency.
2061 case mlir::omp::ClauseTaskDepend::taskdependout:
2062 case mlir::omp::ClauseTaskDepend::taskdependinout:
2063 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2064 break;
2065 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2066 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2067 break;
2068 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2069 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2070 break;
2071 };
2072 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2073 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2074 dds.emplace_back(dd);
2075 }
2076}
2077
2078/// Shared implementation of a callback which adds a termiator for the new block
2079/// created for the branch taken when an openmp construct is cancelled. The
2080/// terminator is saved in \p cancelTerminators. This callback is invoked only
2081/// if there is cancellation inside of the taskgroup body.
2082/// The terminator will need to be fixed to branch to the correct block to
2083/// cleanup the construct.
2084static void
2086 llvm::IRBuilderBase &llvmBuilder,
2087 llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
2088 llvm::omp::Directive cancelDirective) {
2089 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2090 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2091
2092 // ip is currently in the block branched to if cancellation occured.
2093 // We need to create a branch to terminate that block.
2094 llvmBuilder.restoreIP(ip);
2095
2096 // We must still clean up the construct after cancelling it, so we need to
2097 // branch to the block that finalizes the taskgroup.
2098 // That block has not been created yet so use this block as a dummy for now
2099 // and fix this after creating the operation.
2100 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2101 return llvm::Error::success();
2102 };
2103 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2104 // created in case the body contains omp.cancel (which will then expect to be
2105 // able to find this cleanup callback).
2106 ompBuilder.pushFinalizationCB(
2107 {finiCB, cancelDirective, constructIsCancellable(op)});
2108}
2109
2110/// If we cancelled the construct, we should branch to the finalization block of
2111/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2112/// is immediately before the continuation block. Now this finalization has
2113/// been created we can fix the branch.
2114static void
2116 llvm::OpenMPIRBuilder &ompBuilder,
2117 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2118 ompBuilder.popFinalizationCB();
2119 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2120 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2121 assert(cancelBranch->getNumSuccessors() == 1 &&
2122 "cancel branch should have one target");
2123 cancelBranch->setSuccessor(0, constructFini);
2124 }
2125}
2126
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute: one member per
/// private variable whose privatizer reads from its mold argument.
class TaskContextStructManager {
public:
  /// Does not allocate anything; call generateTaskContextStruct() to build
  /// and heap-allocate the structure.
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// Member GEPs in privateDecls order; null entries correspond to
  /// privatizers that have no slot in the structure.
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Pointer to the heap-allocated structure; null until
  /// generateTaskContextStruct() has been called (or if there were no
  /// private declarations).
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2177
2178void TaskContextStructManager::generateTaskContextStruct() {
2179 if (privateDecls.empty())
2180 return;
2181 privateVarTypes.reserve(privateDecls.size());
2182
2183 for (omp::PrivateClauseOp &privOp : privateDecls) {
2184 // Skip private variables which can safely be allocated and initialised
2185 // inside of the task
2186 if (!privOp.readsFromMold())
2187 continue;
2188 Type mlirType = privOp.getType();
2189 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2190 }
2191
2192 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2193 privateVarTypes);
2194
2195 llvm::DataLayout dataLayout =
2196 builder.GetInsertBlock()->getModule()->getDataLayout();
2197 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2198 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2199
2200 // Heap allocate the structure
2201 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2202 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2203 "omp.task.context_ptr");
2204}
2205
2206void TaskContextStructManager::createGEPsToPrivateVars() {
2207 if (!structPtr) {
2208 assert(privateVarTypes.empty());
2209 return;
2210 }
2211
2212 // Create GEPs for each struct member
2213 llvmPrivateVarGEPs.clear();
2214 llvmPrivateVarGEPs.reserve(privateDecls.size());
2215 llvm::Value *zero = builder.getInt32(0);
2216 unsigned i = 0;
2217 for (auto privDecl : privateDecls) {
2218 if (!privDecl.readsFromMold()) {
2219 // Handle this inside of the task so we don't pass unnessecary vars in
2220 llvmPrivateVarGEPs.push_back(nullptr);
2221 continue;
2222 }
2223 llvm::Value *iVal = builder.getInt32(i);
2224 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2225 llvmPrivateVarGEPs.push_back(gep);
2226 i += 1;
2227 }
2228}
2229
2230void TaskContextStructManager::freeStructPtr() {
2231 if (!structPtr)
2232 return;
2233
2234 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2235 // Ensure we don't put the call to free() after the terminator
2236 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2237 builder.CreateFree(structPtr);
2238}
2239
2240/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2241static LogicalResult
2242convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2243 LLVM::ModuleTranslation &moduleTranslation) {
2244 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2245 if (failed(checkImplementationStatus(*taskOp)))
2246 return failure();
2247
2248 PrivateVarsInfo privateVarsInfo(taskOp);
2249 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2250 privateVarsInfo.privatizers};
2251
2252 // Allocate and copy private variables before creating the task. This avoids
2253 // accessing invalid memory if (after this scope ends) the private variables
2254 // are initialized from host variables or if the variables are copied into
2255 // from host variables (firstprivate). The insertion point is just before
2256 // where the code for creating and scheduling the task will go. That puts this
2257 // code outside of the outlined task region, which is what we want because
2258 // this way the initialization and copy regions are executed immediately while
2259 // the host variable data are still live.
2260
2261 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2262 findAllocaInsertPoint(builder, moduleTranslation);
2263
2264 // Not using splitBB() because that requires the current block to have a
2265 // terminator.
2266 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2267 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2268 builder.getContext(), "omp.task.start",
2269 /*Parent=*/builder.GetInsertBlock()->getParent());
2270 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2271 builder.SetInsertPoint(branchToTaskStartBlock);
2272
2273 // Now do this again to make the initialization and copy blocks
2274 llvm::BasicBlock *copyBlock =
2275 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2276 llvm::BasicBlock *initBlock =
2277 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2278
2279 // Now the control flow graph should look like
2280 // starter_block:
2281 // <---- where we started when convertOmpTaskOp was called
2282 // br %omp.private.init
2283 // omp.private.init:
2284 // br %omp.private.copy
2285 // omp.private.copy:
2286 // br %omp.task.start
2287 // omp.task.start:
2288 // <---- where we want the insertion point to be when we call createTask()
2289
2290 // Save the alloca insertion point on ModuleTranslation stack for use in
2291 // nested regions.
2293 moduleTranslation, allocaIP);
2294
2295 // Allocate and initialize private variables
2296 builder.SetInsertPoint(initBlock->getTerminator());
2297
2298 // Create task variable structure
2299 taskStructMgr.generateTaskContextStruct();
2300 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2301 // of the body otherwise it will be the GEP not the struct which is fowarded
2302 // to the outlined function. GEPs forwarded in this way are passed in a
2303 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2304 // which may not be executed until after the current stack frame goes out of
2305 // scope.
2306 taskStructMgr.createGEPsToPrivateVars();
2307
2308 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2309 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2310 privateVarsInfo.blockArgs,
2311 taskStructMgr.getLLVMPrivateVarGEPs())) {
2312 // To be handled inside the task.
2313 if (!privDecl.readsFromMold())
2314 continue;
2315 assert(llvmPrivateVarAlloc &&
2316 "reads from mold so shouldn't have been skipped");
2317
2318 llvm::Expected<llvm::Value *> privateVarOrErr =
2319 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2320 blockArg, llvmPrivateVarAlloc, initBlock);
2321 if (!privateVarOrErr)
2322 return handleError(privateVarOrErr, *taskOp.getOperation());
2323
2325
2326 // TODO: this is a bit of a hack for Fortran character boxes.
2327 // Character boxes are passed by value into the init region and then the
2328 // initialized character box is yielded by value. Here we need to store the
2329 // yielded value into the private allocation, and load the private
2330 // allocation to match the type expected by region block arguments.
2331 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2332 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2333 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2334 // Load it so we have the value pointed to by the GEP
2335 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2336 llvmPrivateVarAlloc);
2337 }
2338 assert(llvmPrivateVarAlloc->getType() ==
2339 moduleTranslation.convertType(blockArg.getType()));
2340
2341 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2342 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2343 // stack allocated structure.
2344 }
2345
2346 // firstprivate copy region
2347 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2348 if (failed(copyFirstPrivateVars(
2349 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2350 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2351 taskOp.getPrivateNeedsBarrier())))
2352 return llvm::failure();
2353
2354 // Set up for call to createTask()
2355 builder.SetInsertPoint(taskStartBlock);
2356
2357 auto bodyCB = [&](InsertPointTy allocaIP,
2358 InsertPointTy codegenIP) -> llvm::Error {
2359 // Save the alloca insertion point on ModuleTranslation stack for use in
2360 // nested regions.
2362 moduleTranslation, allocaIP);
2363
2364 // translate the body of the task:
2365 builder.restoreIP(codegenIP);
2366
2367 llvm::BasicBlock *privInitBlock = nullptr;
2368 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2369 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2370 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2371 privateVarsInfo.mlirVars))) {
2372 auto [blockArg, privDecl, mlirPrivVar] = zip;
2373 // This is handled before the task executes
2374 if (privDecl.readsFromMold())
2375 continue;
2376
2377 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2378 llvm::Type *llvmAllocType =
2379 moduleTranslation.convertType(privDecl.getType());
2380 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2381 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2382 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2383
2384 llvm::Expected<llvm::Value *> privateVarOrError =
2385 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2386 blockArg, llvmPrivateVar, privInitBlock);
2387 if (!privateVarOrError)
2388 return privateVarOrError.takeError();
2389 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2390 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2391 }
2392
2393 taskStructMgr.createGEPsToPrivateVars();
2394 for (auto [i, llvmPrivVar] :
2395 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2396 if (!llvmPrivVar) {
2397 assert(privateVarsInfo.llvmVars[i] &&
2398 "This is added in the loop above");
2399 continue;
2400 }
2401 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2402 }
2403
2404 // Find and map the addresses of each variable within the task context
2405 // structure
2406 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2407 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2408 privateVarsInfo.privatizers)) {
2409 // This was handled above.
2410 if (!privateDecl.readsFromMold())
2411 continue;
2412 // Fix broken pass-by-value case for Fortran character boxes
2413 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2414 llvmPrivateVar = builder.CreateLoad(
2415 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2416 }
2417 assert(llvmPrivateVar->getType() ==
2418 moduleTranslation.convertType(blockArg.getType()));
2419 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2420 }
2421
2422 auto continuationBlockOrError = convertOmpOpRegions(
2423 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2424 if (failed(handleError(continuationBlockOrError, *taskOp)))
2425 return llvm::make_error<PreviouslyReportedError>();
2426
2427 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2428
2429 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2430 privateVarsInfo.llvmVars,
2431 privateVarsInfo.privatizers)))
2432 return llvm::make_error<PreviouslyReportedError>();
2433
2434 // Free heap allocated task context structure at the end of the task.
2435 taskStructMgr.freeStructPtr();
2436
2437 return llvm::Error::success();
2438 };
2439
2440 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2441 SmallVector<llvm::BranchInst *> cancelTerminators;
2442 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2443 // which is canceled. This is handled here because it is the task's cleanup
2444 // block which should be branched to.
2445 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2446 llvm::omp::Directive::OMPD_taskgroup);
2447
2449 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2450 moduleTranslation, dds);
2451
2452 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2453 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2454 moduleTranslation.getOpenMPBuilder()->createTask(
2455 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2456 moduleTranslation.lookupValue(taskOp.getFinal()),
2457 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2458 taskOp.getMergeable(),
2459 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2460 moduleTranslation.lookupValue(taskOp.getPriority()));
2461
2462 if (failed(handleError(afterIP, *taskOp)))
2463 return failure();
2464
2465 // Set the correct branch target for task cancellation
2466 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2467
2468 builder.restoreIP(*afterIP);
2469 return success();
2470}
2471
2472/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2473static LogicalResult
2474convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2475 LLVM::ModuleTranslation &moduleTranslation) {
2476 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2477 if (failed(checkImplementationStatus(*tgOp)))
2478 return failure();
2479
2480 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2481 builder.restoreIP(codegenIP);
2482 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2483 builder, moduleTranslation)
2484 .takeError();
2485 };
2486
2487 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2488 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2489 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2490 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2491 bodyCB);
2492
2493 if (failed(handleError(afterIP, *tgOp)))
2494 return failure();
2495
2496 builder.restoreIP(*afterIP);
2497 return success();
2498}
2499
2500static LogicalResult
2501convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2502 LLVM::ModuleTranslation &moduleTranslation) {
2503 if (failed(checkImplementationStatus(*twOp)))
2504 return failure();
2505
2506 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2507 return success();
2508}
2509
/// Converts an OpenMP workshare loop (an `omp.wsloop` wrapping an
/// `omp.loop_nest`) into LLVM IR using OpenMPIRBuilder.
///
/// Covers: schedule kind/chunk/modifiers, dist_schedule when directly nested
/// in `omp.distribute` ('distribute parallel do'), private and firstprivate
/// variables, reductions, linear clause handling, cancellation finalization,
/// and detection of the "no-loop" SPMD target kernel mode.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration. The IV type is taken from the first loop
  // step; chunk values are extended/truncated to match it.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // If the direct parent is omp.distribute, pick up its dist_schedule
  // configuration (static flag and optional chunk size).
  omp::DistributeOp distributeOp = nullptr;
  llvm::Value *distScheduleChunk = nullptr;
  bool hasDistSchedule = false;
  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
    hasDistSchedule = distributeOp.getDistScheduleStatic();
    if (distributeOp.getDistScheduleChunkSize()) {
      llvm::Value *chunkVar = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
    }
  }

  PrivateVarsInfo privateVarsInfo(wsloopOp);

  // Gather the omp.declare_reduction ops referenced by the reduction clause.
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // Allocate storage for the private copies of privatized values.
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  // Stores whose emission is deferred until after reduction allocation.
  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  // Run the `init` regions of the privatizers for the private copies.
  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // Copy the original values into firstprivate copies (with an optional
  // barrier afterwards when required).
  if (failed(copyFirstPrivateVars(
          wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
          wsloopOp.getPrivateNeedsBarrier())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  // The only legal way for the direct parent to be omp.distribute is that this
  // represents 'distribute parallel do'. Otherwise, this is a regular
  // worksharing loop.
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

  // Register branches that cancellation of the enclosing 'for' must retarget
  // once the continuation block is known (resolved by the pop call below).
  SmallVector<llvm::BranchInst *> cancelTerminators;
  pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
                           llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!wsloopOp.getLinearVars().empty()) {
    auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);

    for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                            linearVar, idx);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  // Lower the region wrapped by the wsloop (the loop nest itself).
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  // The loop nest conversion recorded its CanonicalLoopInfo on the
  // translation stack; retrieve it to apply the worksharing transformation.
  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);

  // Emit Initialization and Update IR for linear variables
  if (!wsloopOp.getLinearVars().empty()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    builder.restoreIP(*afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
    linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Check if we can generate no-loop kernel: only for the top-level loop of a
  // target region that is SPMD with the no_loop flag and not generic.
  bool noLoopMode = false;
  omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
  if (targetOp) {
    Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
    // We need this check because, without it, noLoopMode would be set to true
    // for every omp.wsloop nested inside a no-loop SPMD target region, even if
    // that loop is not the top-level SPMD one.
    if (loopOp == targetCapturedOp) {
      omp::TargetRegionFlags kernelFlags =
          targetOp.getKernelExecFlags(targetCapturedOp);
      if (omp::bitEnumContainsAll(kernelFlags,
                                  omp::TargetRegionFlags::spmd |
                                      omp::TargetRegionFlags::no_loop) &&
          !omp::bitEnumContainsAny(kernelFlags,
                                   omp::TargetRegionFlags::generic))
        noLoopMode = true;
    }
  }

  // Apply the worksharing-loop transformation to the canonical loop.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Emit finalization and in-place rewrites for linear vars.
  if (!wsloopOp.getLinearVars().empty()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                loopInfo->getLastIter());
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                           index);
    builder.restoreIP(oldIP);
  }

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(
          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
          privateReductionVariables, isByRef, wsloopOp.getNowait(),
          /*isTeamsReduction=*/false)))
    return failure();

  // Finally run the `dealloc` regions for private copies.
  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
2718
/// Converts the OpenMP parallel operation to LLVM IR.
///
/// Lowering is driven by OpenMPIRBuilder::createParallel with three
/// callbacks: body generation (privatization, reductions, region lowering),
/// a no-op privatization callback (privatization is handled in the body
/// callback instead), and a finalization callback that inlines reduction
/// cleanup regions, cleans up private copies, and emits a cancellation
/// barrier when the construct is cancellable.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isCancellable = constructIsCancellable(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(opInst);

  // Collect reduction declarations
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Allocate storage for the private copies of privatized values.
        builder, moduleTranslation, privateVarsInfo, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point just before the terminator of the
    // alloca block so new allocas land at the end of that block.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    // Run the `init` regions of the privatizers for the private copies.
    if (handleError(
            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            *opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Copy the original values into firstprivate copies (with an optional
    // barrier afterwards when required).
    if (failed(copyFirstPrivateVars(
            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
          owningReductionGenRefDataPtrGens;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           owningReductionGenRefDataPtrGens,
                           privateReductionVariables, reductionInfos, isByRef);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info. A temporary unreachable instruction
      // marks the insertion point and is removed once the reduction IR has
      // been emitted.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              builder.saveIP(), allocaIP, reductionInfos, isByRef,
              /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }

    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    // Run the `dealloc` regions for private copies.
    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    // If we could be performing cancellation, add the cancellation barrier on
    // the way out of the outlined region.
    if (isCancellable) {
      auto IPOrErr = ompBuilder->createBarrier(
          llvm::OpenMPIRBuilder::LocationDescription(builder),
          llvm::omp::Directive::OMPD_unknown,
          /* ForceSimpleCall */ false,
          /* CheckCancelFlag */ false);
      if (!IPOrErr)
        return IPOrErr.takeError();
    }

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate the if/num_threads/proc_bind clauses where present.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  // Resume emission after the whole parallel construct.
  builder.restoreIP(*afterIP);
  return success();
}
2906
2907/// Convert Order attribute to llvm::omp::OrderKind.
2908static llvm::omp::OrderKind
2909convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2910 if (!o)
2911 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2912 switch (*o) {
2913 case omp::ClauseOrderKind::Concurrent:
2914 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2915 }
2916 llvm_unreachable("Unknown ClauseOrderKind kind");
2917}
2918
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
///
/// Handles private variables (firstprivate is not allowed on simd),
/// reductions (combined per-lane into the original variable after the loop,
/// without any runtime calls since everything runs on one thread), linear
/// variables, and the aligned/if/simdlen/safelen/order clauses.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto simdOp = cast<omp::SimdOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(simdOp);

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
  DenseMap<Value, llvm::Value *> reductionVariableMap;
  SmallVector<llvm::Value *> privateReductionVariables(
      simdOp.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;
  // Gather the omp.declare_reduction ops referenced by the reduction clause.
  collectReductionDecls(simdOp, reductionDecls);
  llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!simdOp.getLinearVars().empty()) {
    auto linearVarTypes = simdOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                            linearVar, idx);
    for (mlir::Value linearStep : simdOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  // Allocate storage for the private copies of privatized values.
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  if (failed(allocReductionVars(simdOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  // Run the `init` regions of the privatizers for the private copies.
  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
  // SIMD.

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(simdOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // Translate the simdlen/safelen clauses where present.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());

  // Build the aligned-variable map: load each aligned pointer in the block
  // preceding the loop and pair it with its alignment constant.
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");

    // Check if the alignment value is not a power of 2. If so, skip emitting
    // alignment.
    if (!intAttr.getValue().isPowerOf2())
      continue;

    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

  // Lower the region wrapped by the simd op (the loop nest itself).
      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
  // Emit Initialization for linear variables
  if (simdOp.getLinearVars().size()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());

    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
  }
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Attach simd metadata/transformations to the canonical loop.
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
    linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                         index);

  // We now need to reduce the per-simd-lane reduction variable into the
  // original variable. This works a bit differently to other reductions (e.g.
  // wsloop) because we don't need to call into the OpenMP runtime to handle
  // threads: everything happened in this one thread.
  for (auto [i, tuple] : llvm::enumerate(
           llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
                     privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

    // We have one less load for by-ref case because that load is now inside of
    // the reduction region.
    llvm::Value *redValue = originalVariable;
    if (!byRef)
      redValue =
          builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        reductionType, privateReductionVar, "red.private.value." + Twine(i));
    llvm::Value *reduced;

    // Inline the combiner from the declare_reduction op at the current IP.
    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    if (failed(handleError(res, opInst)))
      return failure();
    builder.restoreIP(res.get());

    // For by-ref case, the store is inside of the reduction region.
    if (!byRef)
      builder.CreateStore(reduced, originalVariable);
  }

  // After the construct, deallocate private reduction variables.
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                    moduleTranslation, builder,
                                    "omp.reduction.cleanup")))
    return failure();

  // Finally run the `dealloc` regions for private copies.
  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3097
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
///
/// Each rectangular loop of the nest is emitted as a canonical loop via
/// createCanonicalLoop; the resulting loops are optionally tiled and then
/// collapsed, and the collapsed top-level CanonicalLoopInfo is published on
/// the ModuleTranslation stack for enclosing constructs (e.g. wsloop/simd)
/// to transform further.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(opInst);

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body. Invoked once per emitted loop of
  // the nest; only the innermost invocation lowers the MLIR region.
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop body contains the region to translate.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are emitted at the body IP of the loop just outside them,
      // but their trip counts are computed in the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Do tiling.
  if (const auto &tiles = loopOp.getTileSizes()) {
    llvm::Type *ivType = loopInfos.front()->getIndVarType();

    for (auto tile : tiles.value()) {
      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
      tileSizes.push_back(tileVal);
    }

    std::vector<llvm::CanonicalLoopInfo *> newLoops =
        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);

    // Update afterIP to get the correct insertion point after
    // tiling.
    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
    afterIP = {afterAfterBB, afterAfterBB->begin()};

    // Update the loop infos.
    loopInfos.clear();
    for (const auto &newLoop : newLoops)
      loopInfos.push_back(newLoop);
  } // Tiling done.

  // Do collapse.
  const auto &numCollapse = loopOp.getCollapseNumLoops();
      loopInfos.begin(), loopInfos.begin() + (numCollapse));

  auto newTopLoopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});

  assert(newTopLoopInfo && "New top loop information is missing");
  // Publish the collapsed loop on the stack so enclosing constructs can
  // retrieve it via findCurrentLoopInfo.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = newTopLoopInfo;
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);
  return success();
}
3220
/// Convert an omp.canonical_loop to LLVM-IR.
///
/// Emits a single OpenMPIRBuilder canonical loop, lowering the op's region as
/// the loop body. If the op defines a CLI result, the generated
/// CanonicalLoopInfo is registered so that later loop-transformation ops
/// (e.g. unroll) can look it up.
static LogicalResult
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);

      ompBuilder->createCanonicalLoop(
          loopLoc,
          [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
            // Register the mapping of MLIR induction variable to LLVM-IR
            // induction variable
            moduleTranslation.mapValue(loopIV, llvmIV);

            // Lower the loop body region at the body insertion point.
            builder.restoreIP(ip);
                convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
                                    moduleTranslation);

            return bodyGenStatus.takeError();
          },
          llvmTC, "omp.loop");
  if (!llvmOrError)
    return op.emitError(llvm::toString(llvmOrError.takeError()));

  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  // Continue emission after the generated loop.
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(afterIP);

  // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
  if (Value cli = op.getCli())
    moduleTranslation.mapOmpLoop(cli, llvmCLI);

  return success();
}
3262
3263/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3264/// OpenMPIRBuilder.
3265static LogicalResult
3266applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3267 LLVM::ModuleTranslation &moduleTranslation) {
3268 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3269
3270 Value applyee = op.getApplyee();
3271 assert(applyee && "Loop to apply unrolling on required");
3272
3273 llvm::CanonicalLoopInfo *consBuilderCLI =
3274 moduleTranslation.lookupOMPLoop(applyee);
3275 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3276 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3277
3278 moduleTranslation.invalidateOmpLoop(applyee);
3279 return success();
3280}
3281
3282/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3283/// OpenMPIRBuilder.
3284static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3285 LLVM::ModuleTranslation &moduleTranslation) {
3286 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3287 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3288
3290 SmallVector<llvm::Value *> translatedSizes;
3291
3292 for (Value size : op.getSizes()) {
3293 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3294 assert(translatedSize &&
3295 "sizes clause arguments must already be translated");
3296 translatedSizes.push_back(translatedSize);
3297 }
3298
3299 for (Value applyee : op.getApplyees()) {
3300 llvm::CanonicalLoopInfo *consBuilderCLI =
3301 moduleTranslation.lookupOMPLoop(applyee);
3302 assert(applyee && "Canonical loop must already been translated");
3303 translatedLoops.push_back(consBuilderCLI);
3304 }
3305
3306 auto generatedLoops =
3307 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3308 if (!op.getGeneratees().empty()) {
3309 for (auto [mlirLoop, genLoop] :
3310 zip_equal(op.getGeneratees(), generatedLoops))
3311 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3312 }
3313
3314 // CLIs can only be consumed once
3315 for (Value applyee : op.getApplyees())
3316 moduleTranslation.invalidateOmpLoop(applyee);
3317
3318 return success();
3319}
3320
3321/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3322static llvm::AtomicOrdering
3323convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3324 if (!ao)
3325 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3326
3327 switch (*ao) {
3328 case omp::ClauseMemoryOrderKind::Seq_cst:
3329 return llvm::AtomicOrdering::SequentiallyConsistent;
3330 case omp::ClauseMemoryOrderKind::Acq_rel:
3331 return llvm::AtomicOrdering::AcquireRelease;
3332 case omp::ClauseMemoryOrderKind::Acquire:
3333 return llvm::AtomicOrdering::Acquire;
3334 case omp::ClauseMemoryOrderKind::Release:
3335 return llvm::AtomicOrdering::Release;
3336 case omp::ClauseMemoryOrderKind::Relaxed:
3337 return llvm::AtomicOrdering::Monotonic;
3338 }
3339 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3340}
3341
3342/// Convert omp.atomic.read operation to LLVM IR.
3343static LogicalResult
3344convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3345 LLVM::ModuleTranslation &moduleTranslation) {
3346 auto readOp = cast<omp::AtomicReadOp>(opInst);
3347 if (failed(checkImplementationStatus(opInst)))
3348 return failure();
3349
3350 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3351 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3352 findAllocaInsertPoint(builder, moduleTranslation);
3353
3354 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3355
3356 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3357 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3358 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3359
3360 llvm::Type *elementType =
3361 moduleTranslation.convertType(readOp.getElementType());
3362
3363 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3364 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3365 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3366 return success();
3367}
3368
3369/// Converts an omp.atomic.write operation to LLVM IR.
3370static LogicalResult
3371convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3372 LLVM::ModuleTranslation &moduleTranslation) {
3373 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3374 if (failed(checkImplementationStatus(opInst)))
3375 return failure();
3376
3377 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3378 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3379 findAllocaInsertPoint(builder, moduleTranslation);
3380
3381 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3382 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3383 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3384 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3385 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3386 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3387 /*isVolatile=*/false};
3388 builder.restoreIP(
3389 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3390 return success();
3391}
3392
3393/// Converts an LLVM dialect binary operation to the corresponding enum value
3394/// for `atomicrmw` supported binary operation.
3395static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3397 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3398 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3399 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3400 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3401 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3402 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3403 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3404 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3405 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3406 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3407}
3408
3409static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3410 bool &isIgnoreDenormalMode,
3411 bool &isFineGrainedMemory,
3412 bool &isRemoteMemory) {
3413 isIgnoreDenormalMode = false;
3414 isFineGrainedMemory = false;
3415 isRemoteMemory = false;
3416 if (atomicUpdateOp &&
3417 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3418 mlir::omp::AtomicControlAttr atomicControlAttr =
3419 atomicUpdateOp.getAtomicControlAttr();
3420 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3421 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3422 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3423 }
3424}
3425
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// When the update region holds exactly one operation besides the terminator
/// and that operation is a binary op on the region argument, the op is lowered
/// via the IRBuilder with a concrete `atomicrmw` binop; otherwise BAD_BINOP is
/// passed so the IRBuilder emits a compare-exchange loop, re-running the
/// region body (through `updateFn`) on each retry.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Record whether `x` is the LHS of the binary op; needed to lower
    // non-commutative updates (e.g. x = x - expr vs. x = expr - x) correctly.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  // Element type is taken from the region argument, i.e. the value loaded
  // from `x`, not from the pointer operand itself.
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  // Callback invoked by the IRBuilder with the current (loaded) value of `x`;
  // it translates the update region and returns the value to store back.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    // Wire the region argument to the loaded value before translating.
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
3510
/// Converts an omp.atomic.capture operation to LLVM IR.
///
/// The capture region wraps an atomic read together with either an atomic
/// update or an atomic write. Whether the captured value is the old or new
/// value of `x` (`isPostfixUpdate`) depends on whether the read or the
/// update/write comes second inside the capture region.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // A write always replaces `x`, so the captured value is the old one.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      // Whether `x` is the LHS matters for non-commutative update binops.
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one op in the region: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Computes the new value of `x` given its loaded value: for a write this is
  // simply the written expression, for an update it is the translated region.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  // atomicUpdateOp may be null here (write case); the helper handles that.
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
3611
3612static llvm::omp::Directive convertCancellationConstructType(
3613 omp::ClauseCancellationConstructType directive) {
3614 switch (directive) {
3615 case omp::ClauseCancellationConstructType::Loop:
3616 return llvm::omp::Directive::OMPD_for;
3617 case omp::ClauseCancellationConstructType::Parallel:
3618 return llvm::omp::Directive::OMPD_parallel;
3619 case omp::ClauseCancellationConstructType::Sections:
3620 return llvm::omp::Directive::OMPD_sections;
3621 case omp::ClauseCancellationConstructType::Taskgroup:
3622 return llvm::omp::Directive::OMPD_taskgroup;
3623 }
3624 llvm_unreachable("Unhandled cancellation construct type");
3625}
3626
3627static LogicalResult
3628convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3629 LLVM::ModuleTranslation &moduleTranslation) {
3630 if (failed(checkImplementationStatus(*op.getOperation())))
3631 return failure();
3632
3633 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3634 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3635
3636 llvm::Value *ifCond = nullptr;
3637 if (Value ifVar = op.getIfExpr())
3638 ifCond = moduleTranslation.lookupValue(ifVar);
3639
3640 llvm::omp::Directive cancelledDirective =
3641 convertCancellationConstructType(op.getCancelDirective());
3642
3643 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3644 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3645
3646 if (failed(handleError(afterIP, *op.getOperation())))
3647 return failure();
3648
3649 builder.restoreIP(afterIP.get());
3650
3651 return success();
3652}
3653
3654static LogicalResult
3655convertOmpCancellationPoint(omp::CancellationPointOp op,
3656 llvm::IRBuilderBase &builder,
3657 LLVM::ModuleTranslation &moduleTranslation) {
3658 if (failed(checkImplementationStatus(*op.getOperation())))
3659 return failure();
3660
3661 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3662 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3663
3664 llvm::omp::Directive cancelledDirective =
3665 convertCancellationConstructType(op.getCancelDirective());
3666
3667 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3668 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3669
3670 if (failed(handleError(afterIP, *op.getOperation())))
3671 return failure();
3672
3673 builder.restoreIP(afterIP.get());
3674
3675 return success();
3676}
3677
3678/// Converts an OpenMP Threadprivate operation into LLVM IR using
3679/// OpenMPIRBuilder.
3680static LogicalResult
3681convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
3682 LLVM::ModuleTranslation &moduleTranslation) {
3683 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3684 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3685 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3686
3687 if (failed(checkImplementationStatus(opInst)))
3688 return failure();
3689
3690 Value symAddr = threadprivateOp.getSymAddr();
3691 auto *symOp = symAddr.getDefiningOp();
3692
3693 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3694 symOp = asCast.getOperand().getDefiningOp();
3695
3696 if (!isa<LLVM::AddressOfOp>(symOp))
3697 return opInst.emitError("Addressing symbol not found");
3698 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3699
3700 LLVM::GlobalOp global =
3701 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3702 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
3703
3704 if (!ompBuilder->Config.isTargetDevice()) {
3705 llvm::Type *type = globalValue->getValueType();
3706 llvm::TypeSize typeSize =
3707 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3708 type);
3709 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
3710 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3711 ompLoc, globalValue, size, global.getSymName() + ".cache");
3712 moduleTranslation.mapValue(opInst.getResult(0), callInst);
3713 } else {
3714 moduleTranslation.mapValue(opInst.getResult(0), globalValue);
3715 }
3716
3717 return success();
3718}
3719
3720static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3721convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
3722 switch (deviceClause) {
3723 case mlir::omp::DeclareTargetDeviceType::host:
3724 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3725 break;
3726 case mlir::omp::DeclareTargetDeviceType::nohost:
3727 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3728 break;
3729 case mlir::omp::DeclareTargetDeviceType::any:
3730 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3731 break;
3732 }
3733 llvm_unreachable("unhandled device clause");
3734}
3735
3736static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3738 mlir::omp::DeclareTargetCaptureClause captureClause) {
3739 switch (captureClause) {
3740 case mlir::omp::DeclareTargetCaptureClause::to:
3741 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3742 case mlir::omp::DeclareTargetCaptureClause::link:
3743 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3744 case mlir::omp::DeclareTargetCaptureClause::enter:
3745 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3746 case mlir::omp::DeclareTargetCaptureClause::none:
3747 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
3748 }
3749 llvm_unreachable("unhandled capture clause");
3750}
3751
3753 Operation *op = value.getDefiningOp();
3754 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
3755 op = addrCast->getOperand(0).getDefiningOp();
3756 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
3757 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3758 return modOp.lookupSymbol(addressOfOp.getGlobalName());
3759 }
3760 return nullptr;
3761}
3762
3763static llvm::SmallString<64>
3764getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
3765 llvm::OpenMPIRBuilder &ompBuilder) {
3766 llvm::SmallString<64> suffix;
3767 llvm::raw_svector_ostream os(suffix);
3768 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
3769 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
3770 auto fileInfoCallBack = [&loc]() {
3771 return std::pair<std::string, uint64_t>(
3772 llvm::StringRef(loc.getFilename()), loc.getLine());
3773 };
3774
3775 auto vfs = llvm::vfs::getRealFileSystem();
3776 os << llvm::format(
3777 "_%x",
3778 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
3779 }
3780 os << "_decl_tgt_ref_ptr";
3781
3782 return suffix;
3783}
3784
3785static bool isDeclareTargetLink(Value value) {
3786 if (auto declareTargetGlobal =
3787 dyn_cast_if_present<omp::DeclareTargetInterface>(
3788 getGlobalOpFromValue(value)))
3789 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3790 omp::DeclareTargetCaptureClause::link)
3791 return true;
3792 return false;
3793}
3794
3795static bool isDeclareTargetTo(Value value) {
3796 if (auto declareTargetGlobal =
3797 dyn_cast_if_present<omp::DeclareTargetInterface>(
3798 getGlobalOpFromValue(value)))
3799 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3800 omp::DeclareTargetCaptureClause::to ||
3801 declareTargetGlobal.getDeclareTargetCaptureClause() ==
3802 omp::DeclareTargetCaptureClause::enter)
3803 return true;
3804 return false;
3805}
3806
3807// Returns the reference pointer generated by the lowering of the declare
3808// target operation in cases where the link clause is used or the to clause is
3809// used in USM mode.
3810static llvm::Value *
3812 LLVM::ModuleTranslation &moduleTranslation) {
3813 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3814 if (auto gOp =
3815 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
3816 if (auto declareTargetGlobal =
3817 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
3818 // In this case, we must utilise the reference pointer generated by
3819 // the declare target operation, similar to Clang
3820 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3821 omp::DeclareTargetCaptureClause::link) ||
3822 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3823 omp::DeclareTargetCaptureClause::to &&
3824 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3825 llvm::SmallString<64> suffix =
3826 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
3827
3828 if (gOp.getSymName().contains(suffix))
3829 return moduleTranslation.getLLVMModule()->getNamedValue(
3830 gOp.getSymName());
3831
3832 return moduleTranslation.getLLVMModule()->getNamedValue(
3833 (gOp.getSymName().str() + suffix.str()).str());
3834 }
3835 }
3836 }
3837 return nullptr;
3838}
3839
3840namespace {
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // One user-defined mapper operation (or null) per mapped variable, kept
  // parallel to the arrays of the base MapInfosTy.
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
struct MapInfoData : MapInfosTy {
  // Per-variable: true if the mapped value is a declare-target global.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // Per-variable: true if this entry is a member of another mapped variable.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The originating map-clause operation for each entry.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The translated LLVM-IR value before any declare-target indirection.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  /// NOTE(review): IsAMember and IsAMapping are not appended here — confirm
  /// whether that is intentional for every caller of append().
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
3879
// Discriminates which OpenMP target-related directive an operation
// corresponds to; None is used for operations that are not one of the
// listed target directives.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};
3888
3889static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
3890 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
3891 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
3892 .Case([](omp::TargetEnterDataOp) {
3893 return TargetDirectiveEnumTy::TargetEnterData;
3894 })
3895 .Case([&](omp::TargetExitDataOp) {
3896 return TargetDirectiveEnumTy::TargetExitData;
3897 })
3898 .Case([&](omp::TargetUpdateOp) {
3899 return TargetDirectiveEnumTy::TargetUpdate;
3900 })
3901 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
3902 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
3903}
3904
3905} // namespace
3906
3907static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
3908 DataLayout &dl) {
3909 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3910 arrTy.getElementType()))
3911 return getArrayElementSizeInBits(nestedArrTy, dl);
3912 return dl.getTypeSizeInBits(arrTy.getElementType());
3913}
3914
// This function calculates the size to be offloaded for a specified type, given
// its associated map clause (which can contain bounds information which affects
// the total size), this size is calculated based on the underlying element type
// e.g. given a 1-D array of ints, we will calculate the size from the integer
// type * number of elements in the array. This size can be used in other
// calculations but is ultimately used as an argument to the OpenMP runtimes
// kernel argument structure which is generated through the combinedInfo data
// structures.
// This function is somewhat equivalent to Clang's getExprTypeSize inside of
// CGOpenMPRuntime.cpp.
//
// \p type      MLIR type of the mapped entity.
// \p clauseOp  The map-clause operation; only omp.map.info ops with bounds
//              contribute a runtime-computed size.
// \p basePointer / \p baseType are currently unused by this computation.
// Returns an i64 llvm::Value holding the size in bytes.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      // For arrays, the per-element size is what gets multiplied by the
      // bounds-derived element count, so peel down to the innermost element.
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No usable bounds: fall back to the static size of the type itself.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
3974
3975// Convert the MLIR map flag set to the runtime map flag set for embedding
3976// in LLVM-IR. This is important as the two bit-flag lists do not correspond
3977// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
3978// Certain flags are discarded here such as RefPtee and co.
3979static llvm::omp::OpenMPOffloadMappingFlags
3980convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
3981 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
3982 return (mlirFlags & flag) == flag;
3983 };
3984 const bool hasExplicitMap =
3985 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
3986 omp::ClauseMapFlags::none;
3987
3988 llvm::omp::OpenMPOffloadMappingFlags mapType =
3989 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
3990
3991 if (mapTypeToBool(omp::ClauseMapFlags::to))
3992 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
3993
3994 if (mapTypeToBool(omp::ClauseMapFlags::from))
3995 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
3996
3997 if (mapTypeToBool(omp::ClauseMapFlags::always))
3998 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3999
4000 if (mapTypeToBool(omp::ClauseMapFlags::del))
4001 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4002
4003 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4004 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4005
4006 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4007 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4008
4009 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4010 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4011
4012 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4013 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4014
4015 if (mapTypeToBool(omp::ClauseMapFlags::close))
4016 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4017
4018 if (mapTypeToBool(omp::ClauseMapFlags::present))
4019 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4020
4021 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4022 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4023
4024 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4025 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4026
4027 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4028 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4029 if (!hasExplicitMap)
4030 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4031 }
4032
4033 return mapType;
4034}
4035
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
    ArrayRef<Value> useDevAddrOperands = {},
    ArrayRef<Value> hasDevAddrOperands = {}) {
  // Populates `mapData` with the per-entry information (pointers, base
  // pointers, sizes, runtime flags, names, mappers, ...) needed to build the
  // runtime map entries for each map-like operand: regular map clauses
  // (`mapVars`), use_device_ptr/use_device_addr operands and has_device_addr
  // operands.
  //
  // Returns true when `mapOp` appears in the member list of any map entry in
  // `mapVars`, i.e. it maps a member of some larger mapped object.
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    // Prefer the pointer-to-pointer (e.g. a descriptor's base address) when
    // present; otherwise the variable's own address is what gets offloaded.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
      // Declare-target variables with a reference pointer are accessed
      // through that pointer rather than directly.
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else if (isDeclareTargetTo(offloadPtr)) {
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    // Attach the user-defined mapper if this map clause names one.
    if (mapOp.getMapperId())
      mapData.Mappers.push_back(
          mapOp, mapOp.getMapperIdAttr()));
    else
      mapData.Mappers.push_back(nullptr);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  // If `val` was already registered by the map-clause loop above, upgrade it
  // in place (mark it as a returned parameter and record its device-pointer
  // kind) so no duplicate entry is created for use_device_ptr/addr operands.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        // Not previously mapped: add a standalone zero-sized RETURN_PARAM
        // entry so the runtime hands back the device pointer/address.
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(LLVM::createMappingInformation(
            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.Mappers.push_back(nullptr);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);

  // Process hasDevAddrOperands.
  for (Value mapValue : hasDevAddrOperands) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
    auto mapType = convertClauseMapFlags(mapOp.getMapType());
    auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
    bool isDevicePtr =
        (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
        omp::ClauseMapFlags::none;

    mapData.OriginalValue.push_back(origValue);
    mapData.BasePointers.push_back(origValue);
    mapData.Pointers.push_back(origValue);
    mapData.IsDeclareTarget.push_back(false);
    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
    mapData.MapClause.push_back(mapOp.getOperation());
    if (llvm::to_underlying(mapType & mapTypeAlways)) {
      // Descriptors are mapped with the ALWAYS flag, since they can get
      // rematerialized, so the address of the descriptor for a given object
      // may change from one place to another.
      mapData.Types.push_back(mapType);
      // Technically it's possible for a non-descriptor mapping to have
      // both has-device-addr and ALWAYS, so lookup the mapper in case it
      // exists.
      if (mapOp.getMapperId()) {
        mapData.Mappers.push_back(
            mapOp, mapOp.getMapperIdAttr()));
      } else {
        mapData.Mappers.push_back(nullptr);
      }
    } else {
      // For is_device_ptr we need the map type to propagate so the runtime
      // can materialize the device-side copy of the pointer container.
      mapData.Types.push_back(
          isDevicePtr ? mapType
                      : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
      mapData.Mappers.push_back(nullptr);
    }
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(
        isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
                    : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
    mapData.IsAMapping.push_back(false);
    mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
  }
}
4201
4202static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4203 auto *res = llvm::find(mapData.MapClause, memberOp);
4204 assert(res != mapData.MapClause.end() &&
4205 "MapInfoOp for member not found in MapData, cannot return index");
4206 return std::distance(mapData.MapClause.begin(), res);
4207}
4208
                           omp::MapInfoOp mapInfo) {
  // Sorts `indices` (positions into `mapInfo`'s member list) by the members'
  // index paths so parents come before their children and siblings are in
  // ascending order, then removes children fully covered ("occluded") by a
  // mapped parent.
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  llvm::SmallVector<size_t> occludedChildren;
  llvm::sort(
      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
        // Bail early if we are asked to look at the same index. If we do not
        // bail early, we can end up mistakenly adding indices to
        // occludedChildren. This can occur with some types of libc++ hardening.
        if (a == b)
          return false;

        auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
        auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);

        // Lexicographic comparison over the common prefix of the two index
        // paths.
        for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
          int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
          int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();

          if (aIndex == bIndex)
            continue;

          if (aIndex < bIndex)
            return true;

          if (aIndex > bIndex)
            return false;
        }

        // Iterated up until the end of the smallest member and
        // they were found to be equal up to that point, so select
        // the member with the lowest index count, so the "parent"
        bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
        // The longer path is a descendant of the shorter one; record it so it
        // can be pruned below.
        if (memberAParent)
          occludedChildren.push_back(b);
        else
          occludedChildren.push_back(a);
        return memberAParent;
      });

  // We remove children from the index list that are overshadowed by
  // a parent, this prevents us retrieving these as the first or last
  // element when the parent is the correct element in these cases.
  for (auto v : occludedChildren)
    indices.erase(std::remove(indices.begin(), indices.end(), v),
                  indices.end());
}
4256
4257static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4258 bool first) {
4259 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4260 // Only 1 member has been mapped, we can return it.
4261 if (indexAttr.size() == 1)
4262 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4263 llvm::SmallVector<size_t> indices(indexAttr.size());
4264 std::iota(indices.begin(), indices.end(), 0);
4265 sortMapIndices(indices, mapInfo);
4266 return llvm::cast<omp::MapInfoOp>(
4267 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4268 .getDefiningOp());
4269}
4270
4271/// This function calculates the array/pointer offset for map data provided
4272/// with bounds operations, e.g. when provided something like the following:
4273///
4274/// Fortran
4275/// map(tofrom: array(2:5, 3:2))
4276///
4277/// We must calculate the initial pointer offset to pass across, this function
4278/// performs this using bounds.
4279///
4280/// TODO/WARNING: This only supports Fortran's column major indexing currently
4281/// as is noted in the note below and comments in the function, we must extend
4282/// this function when we add a C++ frontend.
4283/// NOTE: which while specified in row-major order it currently needs to be
4284/// flipped for Fortran's column order array allocation and access (as
4285/// opposed to C++'s row-major, hence the backwards processing where order is
4286/// important). This is likely important to keep in mind for the future when
4287/// we incorporate a C++ frontend, both frontends will need to agree on the
4288/// ordering of generated bounds operations (one may have to flip them) to
4289/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
static std::vector<llvm::Value *>
                      llvm::IRBuilderBase &builder, bool isArrayTy,
                      OperandRange bounds) {
  std::vector<llvm::Value *> idx;
  // There's no bounds to calculate an offset from, we can safely
  // ignore and return no indices.
  if (bounds.empty())
    return idx;

  // If we have an array type, then we have its type so can treat it as a
  // normal GEP instruction where the bounds operations are simply indexes
  // into the array. We currently do reverse order of the bounds, which
  // I believe leans more towards Fortran's column-major in memory.
  if (isArrayTy) {
    // Leading zero index dereferences the array pointer itself.
    idx.push_back(builder.getInt64(0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
      }
    }
  } else {
    // If we do not have an array type, but we have bounds, then we're dealing
    // with a pointer that's being treated like an array and we have the
    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
    // offset using a single index which the following loop attempts to
    // compute using the standard column-major algorithm e.g for a 3D array:
    //
    // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
    //
    // It is of note that it's doing column-major rather than row-major at the
    // moment, but having a way for the frontend to indicate which major format
    // to use or standardizing/canonicalizing the order of the bounds to compute
    // the offset may be useful in the future when there's other frontends with
    // different formats.
    // NOTE(review): dimensionIndexSizeOffset is never used in this function —
    // candidate for removal.
    std::vector<llvm::Value *> dimensionIndexSizeOffset;
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        // Seed the accumulator with the outermost lower bound, then fold in
        // each remaining dimension as idx = idx * extent + lower_bound.
        if (i == ((int)bounds.size() - 1))
          idx.emplace_back(
              moduleTranslation.lookupValue(boundOp.getLowerBound()));
        else
          idx.back() = builder.CreateAdd(
              builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
                                                boundOp.getExtent())),
              moduleTranslation.lookupValue(boundOp.getLowerBound()));
      }
    }
  }

  return idx;
}
4346
  // Unpack each IntegerAttr element of `values` into its raw int64_t,
  // appending to `ints`.
  llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
    return cast<IntegerAttr>(value).getInt();
  });
}
4352
// Gathers members that are overlapping in the parent, excluding members that
// themselves overlap, keeping the top-most (closest to parents level) map.
static void
                     omp::MapInfoOp parentOp) {
  // No members mapped, no overlaps.
  if (parentOp.getMembers().empty())
    return;

  // Single member, we can insert and return early.
  if (parentOp.getMembers().size() == 1) {
    overlapMapDataIdxs.push_back(0);
    return;
  }

  // 1) collect list of top-level overlapping members from MemberOp
  ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
  for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
    memberByIndex.push_back(
        std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));

  // Sort the smallest first (higher up the parent -> member chain), so that
  // when we remove members, we remove as much as we can in the initial
  // iterations, shortening the number of passes required.
  llvm::sort(memberByIndex.begin(), memberByIndex.end(),
             [&](auto a, auto b) { return a.second.size() < b.second.size(); });

  // Remove elements from the vector if there is a parent element that
  // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
  // [0,2].. etc.
  for (auto v : memberByIndex) {
    llvm::SmallVector<int64_t> vArr(v.second.size());
    getAsIntegers(v.second, vArr);
    // NOTE(review): if no other member's index path extends v's, this find_if
    // returns the end iterator and dereferencing it is undefined behaviour —
    // verify callers guarantee an overlapping entry exists, or add a guard.
    skipList.push_back(
        *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
          if (v == x)
            return false;
          llvm::SmallVector<int64_t> xArr(x.second.size());
          getAsIntegers(x.second, xArr);
          // x is a descendant of v when v's index path is a prefix of x's.
          return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
                 xArr.size() >= vArr.size();
        }));
  }

  // Collect the indices, as we need the base pointer etc. from the MapData
  // structure which is primarily accessible via index at the moment.
  for (auto v : memberByIndex)
    if (find(skipList.begin(), skipList.end(), v) == skipList.end())
      overlapMapDataIdxs.push_back(v.first);
}
4405
4406// The intent is to verify if the mapped data being passed is a
4407// pointer -> pointee that requires special handling in certain cases,
4408// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4409//
4410// There may be a better way to verify this, but unfortunately with
4411// opaque pointers we lose the ability to easily check if something is
4412// a pointer whilst maintaining access to the underlying type.
4413static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4414 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4415 if (mapOp.getVarPtrPtr())
4416 return true;
4417
4418 // If the map data is declare target with a link clause, then it's represented
4419 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4420 // no relation to pointers.
4421 if (isDeclareTargetLink(mapOp.getVarPtr()))
4422 return true;
4423
4424 return false;
4425}
4426
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause. Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map
// type with it) to indicate that a member is part of this parent and should
// be treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  // Map the first segment of the parent. If a user-defined mapper is attached,
  // include the parent's to/from-style bits (and common modifiers) in this
  // base entry so the mapper receives correct copy semantics via its 'type'
  // parameter. Also keep TARGET_PARAM when required for kernel arguments.
  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  // Detect if this mapping uses a user-defined mapper.
  bool hasUserMapper = mapData.Mappers[mapDataIndex] != nullptr;
  if (hasUserMapper) {
    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
    // Preserve relevant map-type bits from the parent clause. These include
    // the copy direction (TO/FROM), as well as commonly used modifiers that
    // should be visible to the mapper for correct behaviour.
    mapFlags parentFlags = mapData.Types[mapDataIndex];
    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
    baseFlag |= (parentFlags & preserve);
  }

  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Full map: the object's own extent [ptr, ptr + 1) supplies the bounds.
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: bound the segment by the lowest and highest mapped members.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // Single whole-object entry; no overlap splitting is performed.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
    } else {
      llvm::SmallVector<size_t> overlapIdxs;
      // Find all of the members that "overlap", i.e. occlude other members that
      // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
      // [0,2], but not [1,0].
      getOverlappedMembers(overlapIdxs, parentClause);
      // We need to make sure the overlapped members are sorted in order of
      // lowest address to highest address.
      sortMapIndices(overlapIdxs, parentClause);

      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      // TODO: We may want to skip arrays/array sections in this as Clang does.
      // It appears to be an optimisation rather than a necessity though,
      // but this requires further investigation. However, we would have to make
      // sure to not exclude maps with bounds that ARE pointers, as these are
      // processed as separate components, i.e. pointer + data.
      for (auto v : overlapIdxs) {
        auto mapDataOverlapIdx = getMapDataMemberIdx(
            mapData,
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
        combinedInfo.Pointers.emplace_back(lowAddr);
        // Emit the gap [lowAddr, start of this overlapped member) as its own
        // segment, then advance lowAddr past the member.
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
                                  lowAddr),
            builder.getInt64Ty(), /*isSigned=*/true));
        lowAddr = builder.CreateConstGEP1_32(
            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      // Trailing segment [lowAddr, highAddr) after the last overlapped member.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), true));
    }
  }
  return memberOfFlag;
}
4616
// This function is intended to add explicit mappings of members
// of a parent object (e.g. record/derived-type fields) that were mapped on
// the same clause, tagging each generated entry with the parent's MEMBER_OF
// flag so the runtime associates members with their parent allocation.
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  for (auto mappedMembers : parentClause.getMembers()) {
    auto memberClause =
        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);

    assert(memberDataIdx >= 0 && "could not find mapped member of structure");

    // If we're currently mapping a pointer to a block of data, we must
    // initially map the pointer, and then attach/bind the data with a
    // subsequent map to the pointer. This segment of code generates the
    // pointer mapping, which can in certain cases be optimised out as Clang
    // currently does in its lowering. However, for the moment we do not do so,
    // in part as we currently have substantially less information on the data
    // being mapped at this stage.
    if (checkIfPointerMap(memberClause)) {
      auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
      mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Names.emplace_back(
          LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
      combinedInfo.Sizes.emplace_back(builder.getInt64(
          moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
    }

    // Same MemberOfFlag to indicate its link with parent and other members
    // of.
    auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
                                                ? parentClause.getVarPtr()
                                                : parentClause.getVarPtrPtr());
    // Pointer members additionally get PTR_AND_OBJ, except for declare-target
    // `to` parents under target update/data directives.
    if (checkIfPointerMap(memberClause) &&
        (!isDeclTargetTo ||
         (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
          targetDirective != TargetDirectiveEnumTy::TargetData))) {
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
    }

    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        mapData.DevicePointers[memberDataIdx]);
    combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    // A pointer member supplies its own base pointer; otherwise the parent's
    // base is used.
    uint64_t basePointerIndex =
        checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[basePointerIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);

    llvm::Value *size = mapData.Sizes[memberDataIdx];
    // Map a zero size when a pointer member is null so the runtime does not
    // copy through it.
    if (checkIfPointerMap(memberClause)) {
      size = builder.CreateSelect(
          builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
          builder.getInt64(0), size);
    }

    combinedInfo.Sizes.emplace_back(size);
  }
}
4700
4701static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
4702 MapInfosTy &combinedInfo,
4703 TargetDirectiveEnumTy targetDirective,
4704 int mapDataParentIdx = -1) {
4705 // Declare Target Mappings are excluded from being marked as
4706 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
4707 // marked with OMP_MAP_PTR_AND_OBJ instead.
4708 auto mapFlag = mapData.Types[mapDataIdx];
4709 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4710
4711 bool isPtrTy = checkIfPointerMap(mapInfoOp);
4712 if (isPtrTy)
4713 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4714
4715 if (targetDirective == TargetDirectiveEnumTy::Target &&
4716 !mapData.IsDeclareTarget[mapDataIdx])
4717 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4718
4719 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4720 !isPtrTy)
4721 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4722
4723 // if we're provided a mapDataParentIdx, then the data being mapped is
4724 // part of a larger object (in a parent <-> member mapping) and in this
4725 // case our BasePointer should be the parent.
4726 if (mapDataParentIdx >= 0)
4727 combinedInfo.BasePointers.emplace_back(
4728 mapData.BasePointers[mapDataParentIdx]);
4729 else
4730 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
4731
4732 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
4733 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
4734 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
4735 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
4736 combinedInfo.Types.emplace_back(mapFlag);
4737 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
4738}
4739
4741 llvm::IRBuilderBase &builder,
4742 llvm::OpenMPIRBuilder &ompBuilder,
4743 DataLayout &dl, MapInfosTy &combinedInfo,
4744 MapInfoData &mapData, uint64_t mapDataIndex,
4745 TargetDirectiveEnumTy targetDirective) {
4746 assert(!ompBuilder.Config.isTargetDevice() &&
4747 "function only supported for host device codegen");
4748
4749 auto parentClause =
4750 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4751
4752 // If we have a partial map (no parent referenced in the map clauses of the
4753 // directive, only members) and only a single member, we do not need to bind
4754 // the map of the member to the parent, we can pass the member separately.
4755 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4756 auto memberClause = llvm::cast<omp::MapInfoOp>(
4757 parentClause.getMembers()[0].getDefiningOp());
4758 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
4759 // Note: Clang treats arrays with explicit bounds that fall into this
4760 // category as a parent with map case, however, it seems this isn't a
4761 // requirement, and processing them as an individual map is fine. So,
4762 // we will handle them as individual maps for the moment, as it's
4763 // difficult for us to check this as we always require bounds to be
4764 // specified currently and it's also marginally more optimal (single
4765 // map rather than two). The difference may come from the fact that
4766 // Clang maps array without bounds as pointers (which we do not
4767 // currently do), whereas we treat them as arrays in all cases
4768 // currently.
4769 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
4770 mapDataIndex);
4771 return;
4772 }
4773
4774 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4775 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
4776 combinedInfo, mapData, mapDataIndex,
4777 targetDirective);
4778 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
4779 combinedInfo, mapData, mapDataIndex,
4780 memberOfParentFlag, targetDirective);
4781}
4782
// This is a variation on Clang's GenerateOpenMPCapturedVars, which
// generates different operation (e.g. load/store) combinations for
// arguments to the kernel, based on map capture kinds which are then
// utilised in the combinedInfo in place of the original Map value.
//
// Mutates mapData.Pointers[i] (and, for by-copy, BasePointers[i]) in place;
// declare-target entries are left untouched. Host-side codegen only.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defines, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        // Offsets from bounds clauses (array-section lower bounds) become a
        // GEP applied after any pointer dereference.
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        if (isPtrTy)
          newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        if (!isPtrTy) {
          // Spill the loaded value to a function-entry alloca and reload it,
          // so the by-copy argument has an addressable slot. The insertion
          // point and debug location are saved/restored around the alloca so
          // only the alloca itself is hoisted to the entry block.
          auto curInsert = builder.saveIP();
          llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
          builder.SetCurrentDebugLocation(DbgLoc);
          builder.restoreIP(curInsert);

          builder.CreateStore(newV, memTempAlloc);
          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        // Not supported yet; surface a diagnostic rather than miscompiling.
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}
4853
4854// Generate all map related information and fill the combinedInfo.
4855static void genMapInfos(llvm::IRBuilderBase &builder,
4856 LLVM::ModuleTranslation &moduleTranslation,
4857 DataLayout &dl, MapInfosTy &combinedInfo,
4858 MapInfoData &mapData,
4859 TargetDirectiveEnumTy targetDirective) {
4860 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4861 "function only supported for host device codegen");
4862 // We wish to modify some of the methods in which arguments are
4863 // passed based on their capture type by the target region, this can
4864 // involve generating new loads and stores, which changes the
4865 // MLIR value to LLVM value mapping, however, we only wish to do this
4866 // locally for the current function/target and also avoid altering
4867 // ModuleTranslation, so we remap the base pointer or pointer stored
4868 // in the map infos corresponding MapInfoData, which is later accessed
4869 // by genMapInfos and createTarget to help generate the kernel and
4870 // kernel arg structure. It primarily becomes relevant in cases like
4871 // bycopy, or byref range'd arrays. In the default case, we simply
4872 // pass thee pointer byref as both basePointer and pointer.
4873 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
4874
4875 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4876
4877 // We operate under the assumption that all vectors that are
4878 // required in MapInfoData are of equal lengths (either filled with
4879 // default constructed data or appropiate information) so we can
4880 // utilise the size from any component of MapInfoData, if we can't
4881 // something is missing from the initial MapInfoData construction.
4882 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4883 // NOTE/TODO: We currently do not support arbitrary depth record
4884 // type mapping.
4885 if (mapData.IsAMember[i])
4886 continue;
4887
4888 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4889 if (!mapInfoOp.getMembers().empty()) {
4890 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
4891 combinedInfo, mapData, i, targetDirective);
4892 continue;
4893 }
4894
4895 processIndividualMap(mapData, i, combinedInfo, targetDirective);
4896 }
4897}
4898
// Forward declaration: emits the LLVM function implementing an
// omp.declare_mapper op. Defined below, after its caller
// getOrCreateUserDefinedMapperFunc (the two are mutually recursive through
// nested custom mappers).
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective);
4904
4905static llvm::Expected<llvm::Function *>
4906getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
4907 LLVM::ModuleTranslation &moduleTranslation,
4908 TargetDirectiveEnumTy targetDirective) {
4909 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4910 "function only supported for host device codegen");
4911 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4912 std::string mapperFuncName =
4913 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
4914 {"omp_mapper", declMapperOp.getSymName()});
4915
4916 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4917 return lookupFunc;
4918
4919 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
4920 mapperFuncName, targetDirective);
4921}
4922
// Emit the LLVM function implementing a user-defined mapper
// (omp.declare_mapper). The heavy lifting is delegated to
// OpenMPIRBuilder::emitUserDefinedMapper; this function supplies two
// callbacks: one that inlines the mapper region and produces the map info
// arrays, and one that resolves nested custom mappers. The emitted function
// is registered with ModuleTranslation under `mapperFuncName`.
// Host-side codegen only.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's symbolic variable to the pointer PHI supplied by the
    // OpenMPIRBuilder, then inline the mapper region's body at the current
    // insertion point.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Resolve nested user-defined mappers; combinedInfo is populated by the
  // time the OpenMPIRBuilder invokes this.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  moduleTranslation.mapFunction(mapperFuncName, *newFn);
  return *newFn;
}
4978
4979static LogicalResult
4980convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
4981 LLVM::ModuleTranslation &moduleTranslation) {
4982 llvm::Value *ifCond = nullptr;
4983 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4984 SmallVector<Value> mapVars;
4985 SmallVector<Value> useDevicePtrVars;
4986 SmallVector<Value> useDeviceAddrVars;
4987 llvm::omp::RuntimeFunction RTLFn;
4988 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
4989 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
4990
4991 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4992 llvm::OpenMPIRBuilder::TargetDataInfo info(
4993 /*RequiresDevicePointerInfo=*/true,
4994 /*SeparateBeginEndCalls=*/true);
4995 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
4996 bool isOffloadEntry =
4997 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
4998
4999 LogicalResult result =
5001 .Case([&](omp::TargetDataOp dataOp) {
5002 if (failed(checkImplementationStatus(*dataOp)))
5003 return failure();
5004
5005 if (auto ifVar = dataOp.getIfExpr())
5006 ifCond = moduleTranslation.lookupValue(ifVar);
5007
5008 if (auto devId = dataOp.getDevice())
5009 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5010 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5011 deviceID = intAttr.getInt();
5012
5013 mapVars = dataOp.getMapVars();
5014 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5015 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5016 return success();
5017 })
5018 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5019 if (failed(checkImplementationStatus(*enterDataOp)))
5020 return failure();
5021
5022 if (auto ifVar = enterDataOp.getIfExpr())
5023 ifCond = moduleTranslation.lookupValue(ifVar);
5024
5025 if (auto devId = enterDataOp.getDevice())
5026 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5027 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5028 deviceID = intAttr.getInt();
5029 RTLFn =
5030 enterDataOp.getNowait()
5031 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5032 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5033 mapVars = enterDataOp.getMapVars();
5034 info.HasNoWait = enterDataOp.getNowait();
5035 return success();
5036 })
5037 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5038 if (failed(checkImplementationStatus(*exitDataOp)))
5039 return failure();
5040
5041 if (auto ifVar = exitDataOp.getIfExpr())
5042 ifCond = moduleTranslation.lookupValue(ifVar);
5043
5044 if (auto devId = exitDataOp.getDevice())
5045 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5046 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5047 deviceID = intAttr.getInt();
5048
5049 RTLFn = exitDataOp.getNowait()
5050 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5051 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5052 mapVars = exitDataOp.getMapVars();
5053 info.HasNoWait = exitDataOp.getNowait();
5054 return success();
5055 })
5056 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5057 if (failed(checkImplementationStatus(*updateDataOp)))
5058 return failure();
5059
5060 if (auto ifVar = updateDataOp.getIfExpr())
5061 ifCond = moduleTranslation.lookupValue(ifVar);
5062
5063 if (auto devId = updateDataOp.getDevice())
5064 if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
5065 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5066 deviceID = intAttr.getInt();
5067
5068 RTLFn =
5069 updateDataOp.getNowait()
5070 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5071 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5072 mapVars = updateDataOp.getMapVars();
5073 info.HasNoWait = updateDataOp.getNowait();
5074 return success();
5075 })
5076 .DefaultUnreachable("unexpected operation");
5077
5078 if (failed(result))
5079 return failure();
5080 // Pretend we have IF(false) if we're not doing offload.
5081 if (!isOffloadEntry)
5082 ifCond = builder.getFalse();
5083
5084 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5085 MapInfoData mapData;
5086 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
5087 builder, useDevicePtrVars, useDeviceAddrVars);
5088
5089 // Fill up the arrays with all the mapped variables.
5090 MapInfosTy combinedInfo;
5091 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5092 builder.restoreIP(codeGenIP);
5093 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5094 targetDirective);
5095 return combinedInfo;
5096 };
5097
5098 // Define a lambda to apply mappings between use_device_addr and
5099 // use_device_ptr base pointers, and their associated block arguments.
5100 auto mapUseDevice =
5101 [&moduleTranslation](
5102 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5104 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
5105 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
5106 for (auto [arg, useDevVar] :
5107 llvm::zip_equal(blockArgs, useDeviceVars)) {
5108
5109 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5110 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5111 : mapInfoOp.getVarPtr();
5112 };
5113
5114 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5115 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5116 mapInfoData.MapClause, mapInfoData.DevicePointers,
5117 mapInfoData.BasePointers)) {
5118 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5119 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5120 devicePointer != type)
5121 continue;
5122
5123 if (llvm::Value *devPtrInfoMap =
5124 mapper ? mapper(basePointer) : basePointer) {
5125 moduleTranslation.mapValue(arg, devPtrInfoMap);
5126 break;
5127 }
5128 }
5129 }
5130 };
5131
5132 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5133 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5134 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5135 // We must always restoreIP regardless of doing anything the caller
5136 // does not restore it, leading to incorrect (no) branch generation.
5137 builder.restoreIP(codeGenIP);
5138 assert(isa<omp::TargetDataOp>(op) &&
5139 "BodyGen requested for non TargetDataOp");
5140 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5141 Region &region = cast<omp::TargetDataOp>(op).getRegion();
5142 switch (bodyGenType) {
5143 case BodyGenTy::Priv:
5144 // Check if any device ptr/addr info is available
5145 if (!info.DevicePtrInfoMap.empty()) {
5146 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5147 blockArgIface.getUseDeviceAddrBlockArgs(),
5148 useDeviceAddrVars, mapData,
5149 [&](llvm::Value *basePointer) -> llvm::Value * {
5150 if (!info.DevicePtrInfoMap[basePointer].second)
5151 return nullptr;
5152 return builder.CreateLoad(
5153 builder.getPtrTy(),
5154 info.DevicePtrInfoMap[basePointer].second);
5155 });
5156 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5157 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5158 mapData, [&](llvm::Value *basePointer) {
5159 return info.DevicePtrInfoMap[basePointer].second;
5160 });
5161
5162 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5163 moduleTranslation)))
5164 return llvm::make_error<PreviouslyReportedError>();
5165 }
5166 break;
5167 case BodyGenTy::DupNoPriv:
5168 if (info.DevicePtrInfoMap.empty()) {
5169 // For host device we still need to do the mapping for codegen,
5170 // otherwise it may try to lookup a missing value.
5171 if (!ompBuilder->Config.IsTargetDevice.value_or(false)) {
5172 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5173 blockArgIface.getUseDeviceAddrBlockArgs(),
5174 useDeviceAddrVars, mapData);
5175 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5176 blockArgIface.getUseDevicePtrBlockArgs(),
5177 useDevicePtrVars, mapData);
5178 }
5179 }
5180 break;
5181 case BodyGenTy::NoPriv:
5182 // If device info is available then region has already been generated
5183 if (info.DevicePtrInfoMap.empty()) {
5184 // For device pass, if use_device_ptr(addr) mappings were present,
5185 // we need to link them here before codegen.
5186 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
5187 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5188 blockArgIface.getUseDeviceAddrBlockArgs(),
5189 useDeviceAddrVars, mapData);
5190 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5191 blockArgIface.getUseDevicePtrBlockArgs(),
5192 useDevicePtrVars, mapData);
5193 }
5194
5195 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5196 moduleTranslation)))
5197 return llvm::make_error<PreviouslyReportedError>();
5198 }
5199 break;
5200 }
5201 return builder.saveIP();
5202 };
5203
5204 auto customMapperCB =
5205 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5206 if (!combinedInfo.Mappers[i])
5207 return nullptr;
5208 info.HasMapper = true;
5209 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5210 moduleTranslation, targetDirective);
5211 };
5212
5213 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5214 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5215 findAllocaInsertPoint(builder, moduleTranslation);
5216 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5217 if (isa<omp::TargetDataOp>(op))
5218 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5219 builder.getInt64(deviceID), ifCond,
5220 info, genMapInfoCB, customMapperCB,
5221 /*MapperFunc=*/nullptr, bodyGenCB,
5222 /*DeviceAddrCB=*/nullptr);
5223 return ompBuilder->createTargetData(
5224 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
5225 info, genMapInfoCB, customMapperCB, &RTLFn);
5226 }();
5227
5228 if (failed(handleError(afterIP, *op)))
5229 return failure();
5230
5231 builder.restoreIP(*afterIP);
5232 return success();
5233}
5234
5235static LogicalResult
5236convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
5237 LLVM::ModuleTranslation &moduleTranslation) {
5238 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5239 auto distributeOp = cast<omp::DistributeOp>(opInst);
5240 if (failed(checkImplementationStatus(opInst)))
5241 return failure();
5242
5243 /// Process teams op reduction in distribute if the reduction is contained in
5244 /// the distribute op.
5245 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
5246 bool doDistributeReduction =
5247 teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;
5248
5249 DenseMap<Value, llvm::Value *> reductionVariableMap;
5250 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5252 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
5253 llvm::ArrayRef<bool> isByRef;
5254
5255 if (doDistributeReduction) {
5256 isByRef = getIsByRef(teamsOp.getReductionByref());
5257 assert(isByRef.size() == teamsOp.getNumReductionVars());
5258
5259 collectReductionDecls(teamsOp, reductionDecls);
5260 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5261 findAllocaInsertPoint(builder, moduleTranslation);
5262
5263 MutableArrayRef<BlockArgument> reductionArgs =
5264 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5265 .getReductionBlockArgs();
5266
5268 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5269 reductionDecls, privateReductionVariables, reductionVariableMap,
5270 isByRef)))
5271 return failure();
5272 }
5273
5274 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5275 auto bodyGenCB = [&](InsertPointTy allocaIP,
5276 InsertPointTy codeGenIP) -> llvm::Error {
5277 // Save the alloca insertion point on ModuleTranslation stack for use in
5278 // nested regions.
5280 moduleTranslation, allocaIP);
5281
5282 // DistributeOp has only one region associated with it.
5283 builder.restoreIP(codeGenIP);
5284 PrivateVarsInfo privVarsInfo(distributeOp);
5285
5287 allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
5288 if (handleError(afterAllocas, opInst).failed())
5289 return llvm::make_error<PreviouslyReportedError>();
5290
5291 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
5292 opInst)
5293 .failed())
5294 return llvm::make_error<PreviouslyReportedError>();
5295
5296 if (failed(copyFirstPrivateVars(
5297 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
5298 privVarsInfo.llvmVars, privVarsInfo.privatizers,
5299 distributeOp.getPrivateNeedsBarrier())))
5300 return llvm::make_error<PreviouslyReportedError>();
5301
5302 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5303 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5305 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
5306 builder, moduleTranslation);
5307 if (!regionBlock)
5308 return regionBlock.takeError();
5309 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5310
5311 // Skip applying a workshare loop below when translating 'distribute
5312 // parallel do' (it's been already handled by this point while translating
5313 // the nested omp.wsloop).
5314 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5315 // TODO: Add support for clauses which are valid for DISTRIBUTE
5316 // constructs. Static schedule is the default.
5317 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5318 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5319 : omp::ClauseScheduleKind::Static;
5320 // dist_schedule clauses are ordered - otherise this should be false
5321 bool isOrdered = hasDistSchedule;
5322 std::optional<omp::ScheduleModifier> scheduleMod;
5323 bool isSimd = false;
5324 llvm::omp::WorksharingLoopType workshareLoopType =
5325 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5326 bool loopNeedsBarrier = false;
5327 llvm::Value *chunk = moduleTranslation.lookupValue(
5328 distributeOp.getDistScheduleChunkSize());
5329 llvm::CanonicalLoopInfo *loopInfo =
5330 findCurrentLoopInfo(moduleTranslation);
5331 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5332 ompBuilder->applyWorkshareLoop(
5333 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5334 convertToScheduleKind(schedule), chunk, isSimd,
5335 scheduleMod == omp::ScheduleModifier::monotonic,
5336 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5337 workshareLoopType, false, hasDistSchedule, chunk);
5338
5339 if (!wsloopIP)
5340 return wsloopIP.takeError();
5341 }
5342 if (failed(cleanupPrivateVars(builder, moduleTranslation,
5343 distributeOp.getLoc(), privVarsInfo.llvmVars,
5344 privVarsInfo.privatizers)))
5345 return llvm::make_error<PreviouslyReportedError>();
5346
5347 return llvm::Error::success();
5348 };
5349
5350 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5351 findAllocaInsertPoint(builder, moduleTranslation);
5352 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5353 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5354 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5355
5356 if (failed(handleError(afterIP, opInst)))
5357 return failure();
5358
5359 builder.restoreIP(*afterIP);
5360
5361 if (doDistributeReduction) {
5362 // Process the reductions if required.
5364 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5365 privateReductionVariables, isByRef,
5366 /*isNoWait*/ false, /*isTeamsReduction*/ true);
5367 }
5368 return success();
5369}
5370
5371/// Lowers the FlagsAttr which is applied to the module on the device
5372/// pass when offloading, this attribute contains OpenMP RTL globals that can
5373/// be passed as flags to the frontend, otherwise they are set to default
5374static LogicalResult
5375convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5376 LLVM::ModuleTranslation &moduleTranslation) {
5377 if (!cast<mlir::ModuleOp>(op))
5378 return failure();
5379
5380 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5381
5382 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5383 attribute.getOpenmpDeviceVersion());
5384
5385 if (attribute.getNoGpuLib())
5386 return success();
5387
5388 ompBuilder->createGlobalFlag(
5389 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5390 "__omp_rtl_debug_kind");
5391 ompBuilder->createGlobalFlag(
5392 attribute
5393 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5394 ,
5395 "__omp_rtl_assume_teams_oversubscription");
5396 ompBuilder->createGlobalFlag(
5397 attribute
5398 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5399 ,
5400 "__omp_rtl_assume_threads_oversubscription");
5401 ompBuilder->createGlobalFlag(
5402 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5403 "__omp_rtl_assume_no_thread_state");
5404 ompBuilder->createGlobalFlag(
5405 attribute
5406 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5407 ,
5408 "__omp_rtl_assume_no_nested_parallelism");
5409 return success();
5410}
5411
5412static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5413 omp::TargetOp targetOp,
5414 llvm::StringRef parentName = "") {
5415 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5416
5417 assert(fileLoc && "No file found from location");
5418 StringRef fileName = fileLoc.getFilename().getValue();
5419
5420 llvm::sys::fs::UniqueID id;
5421 uint64_t line = fileLoc.getLine();
5422 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5423 size_t fileHash = llvm::hash_value(fileName.str());
5424 size_t deviceId = 0xdeadf17e;
5425 targetInfo =
5426 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5427 } else {
5428 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5429 id.getFile(), line);
5430 }
5431}
5432
5433static void
5434handleDeclareTargetMapVar(MapInfoData &mapData,
5435 LLVM::ModuleTranslation &moduleTranslation,
5436 llvm::IRBuilderBase &builder, llvm::Function *func) {
5437 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5438 "function only supported for target device codegen");
5439 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5440 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5441 // In the case of declare target mapped variables, the basePointer is
5442 // the reference pointer generated by the convertDeclareTargetAttr
5443 // method. Whereas the kernelValue is the original variable, so for
5444 // the device we must replace all uses of this original global variable
5445 // (stored in kernelValue) with the reference pointer (stored in
5446 // basePointer for declare target mapped variables), as for device the
5447 // data is mapped into this reference pointer and should be loaded
5448 // from it, the original variable is discarded. On host both exist and
5449 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
5450 // function to link the two variables in the runtime and then both the
5451 // reference pointer and the pointer are assigned in the kernel argument
5452 // structure for the host.
5453 if (mapData.IsDeclareTarget[i]) {
5454 // If the original map value is a constant, then we have to make sure all
5455 // of it's uses within the current kernel/function that we are going to
5456 // rewrite are converted to instructions, as we will be altering the old
5457 // use (OriginalValue) from a constant to an instruction, which will be
5458 // illegal and ICE the compiler if the user is a constant expression of
5459 // some kind e.g. a constant GEP.
5460 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5461 convertUsersOfConstantsToInstructions(constant, func, false);
5462
5463 // The users iterator will get invalidated if we modify an element,
5464 // so we populate this vector of uses to alter each user on an
5465 // individual basis to emit its own load (rather than one load for
5466 // all).
5468 for (llvm::User *user : mapData.OriginalValue[i]->users())
5469 userVec.push_back(user);
5470
5471 for (llvm::User *user : userVec) {
5472 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
5473 if (insn->getFunction() == func) {
5474 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5475 llvm::Value *substitute = mapData.BasePointers[i];
5476 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
5477 : mapOp.getVarPtr())) {
5478 builder.SetCurrentDebugLocation(insn->getDebugLoc());
5479 substitute = builder.CreateLoad(
5480 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
5481 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
5482 }
5483 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
5484 }
5485 }
5486 }
5487 }
5488 }
5489}
5490
// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generated different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  builder.restoreIP(allocaIP);

  // Default to ByRef if no matching map clause is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // When the alloca address space differs from the program address space
  // (e.g. on some GPU targets), cast the pointer so later users see the
  // default address space.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy: the stack slot holding the argument is used directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // ByRef: load the pointer back out of the stack slot.
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //                 |         |
    //    alignment of %input address    |
    //                           |
    //                           alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
5617
5618/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
5619/// operation and populate output variables with their corresponding host value
5620/// (i.e. operand evaluated outside of the target region), based on their uses
5621/// inside of the target region.
5622///
5623/// Loop bounds and steps are only optionally populated, if output vectors are
5624/// provided.
5625static void
5626extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
5627 Value &numTeamsLower, Value &numTeamsUpper,
5628 Value &threadLimit,
5629 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
5630 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
5631 llvm::SmallVectorImpl<Value> *steps = nullptr) {
5632 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
5633 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5634 blockArgIface.getHostEvalBlockArgs())) {
5635 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5636
5637 for (Operation *user : blockArg.getUsers()) {
5639 .Case([&](omp::TeamsOp teamsOp) {
5640 if (teamsOp.getNumTeamsLower() == blockArg)
5641 numTeamsLower = hostEvalVar;
5642 else if (teamsOp.getNumTeamsUpper() == blockArg)
5643 numTeamsUpper = hostEvalVar;
5644 else if (teamsOp.getThreadLimit() == blockArg)
5645 threadLimit = hostEvalVar;
5646 else
5647 llvm_unreachable("unsupported host_eval use");
5648 })
5649 .Case([&](omp::ParallelOp parallelOp) {
5650 if (parallelOp.getNumThreads() == blockArg)
5651 numThreads = hostEvalVar;
5652 else
5653 llvm_unreachable("unsupported host_eval use");
5654 })
5655 .Case([&](omp::LoopNestOp loopOp) {
5656 auto processBounds =
5657 [&](OperandRange opBounds,
5658 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
5659 bool found = false;
5660 for (auto [i, lb] : llvm::enumerate(opBounds)) {
5661 if (lb == blockArg) {
5662 found = true;
5663 if (outBounds)
5664 (*outBounds)[i] = hostEvalVar;
5665 }
5666 }
5667 return found;
5668 };
5669 bool found =
5670 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5671 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5672 found;
5673 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5674 (void)found;
5675 assert(found && "unsupported host_eval use");
5676 })
5677 .DefaultUnreachable("unsupported host_eval use");
5678 }
5679 }
5680}
5681
5682/// If \p op is of the given type parameter, return it casted to that type.
5683/// Otherwise, if its immediate parent operation (or some other higher-level
5684/// parent, if \p immediateParent is false) is of that type, return that parent
5685/// casted to the given type.
5686///
5687/// If \p op is \c null or neither it or its parent(s) are of the specified
5688/// type, return a \c null operation.
5689template <typename OpTy>
5690static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
5691 if (!op)
5692 return OpTy();
5693
5694 if (OpTy casted = dyn_cast<OpTy>(op))
5695 return casted;
5696
5697 if (immediateParent)
5698 return dyn_cast_if_present<OpTy>(op->getParentOp());
5699
5700 return op->getParentOfType<OpTy>();
5701}
5702
5703/// If the given \p value is defined by an \c llvm.mlir.constant operation and
5704/// it is of an integer type, return its value.
5705static std::optional<int64_t> extractConstInteger(Value value) {
5706 if (!value)
5707 return std::nullopt;
5708
5709 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
5710 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5711 return constAttr.getInt();
5712
5713 return std::nullopt;
5714}
5715
5716static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
5717 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
5718 uint64_t sizeInBytes = sizeInBits / 8;
5719 return sizeInBytes;
5720}
5721
5722template <typename OpTy>
5723static uint64_t getReductionDataSize(OpTy &op) {
5724 if (op.getNumReductionVars() > 0) {
5726 collectReductionDecls(op, reductions);
5727
5729 members.reserve(reductions.size());
5730 for (omp::DeclareReductionOp &red : reductions)
5731 members.push_back(red.getType());
5732 Operation *opp = op.getOperation();
5733 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5734 opp->getContext(), members, /*isPacked=*/false);
5735 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
5736 return getTypeByteSize(structType, dl);
5737 }
5738 return 0;
5739}
5740
5741/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
5742/// values as stated by the corresponding clauses, if constant.
5743///
5744/// These default values must be set before the creation of the outlined LLVM
5745/// function for the target region, so that they can be used to initialize the
5746/// corresponding global `ConfigurationEnvironmentTy` structure.
5747static void
5748initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
5749 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5750 bool isTargetDevice, bool isGPU) {
5751 // TODO: Handle constant 'if' clauses.
5752
5753 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5754 if (!isTargetDevice) {
5755 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
5756 threadLimit);
5757 } else {
5758 // In the target device, values for these clauses are not passed as
5759 // host_eval, but instead evaluated prior to entry to the region. This
5760 // ensures values are mapped and available inside of the target region.
5761 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5762 numTeamsLower = teamsOp.getNumTeamsLower();
5763 numTeamsUpper = teamsOp.getNumTeamsUpper();
5764 threadLimit = teamsOp.getThreadLimit();
5765 }
5766
5767 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5768 numThreads = parallelOp.getNumThreads();
5769 }
5770
5771 // Handle clauses impacting the number of teams.
5772
5773 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5774 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5775 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
5776 // match clang and set min and max to the same value.
5777 if (numTeamsUpper) {
5778 if (auto val = extractConstInteger(numTeamsUpper))
5779 minTeamsVal = maxTeamsVal = *val;
5780 } else {
5781 minTeamsVal = maxTeamsVal = 0;
5782 }
5783 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5784 /*immediateParent=*/true) ||
5786 /*immediateParent=*/true)) {
5787 minTeamsVal = maxTeamsVal = 1;
5788 } else {
5789 minTeamsVal = maxTeamsVal = -1;
5790 }
5791
5792 // Handle clauses impacting the number of threads.
5793
5794 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
5795 if (!clauseValue)
5796 return;
5797
5798 if (auto val = extractConstInteger(clauseValue))
5799 result = *val;
5800
5801 // Found an applicable clause, so it's not undefined. Mark as unknown
5802 // because it's not constant.
5803 if (result < 0)
5804 result = 0;
5805 };
5806
5807 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
5808 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5809 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5810 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5811
5812 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
5813 int32_t maxThreadsVal = -1;
5815 setMaxValueFromClause(numThreads, maxThreadsVal);
5816 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
5817 /*immediateParent=*/true))
5818 maxThreadsVal = 1;
5819
5820 // For max values, < 0 means unset, == 0 means set but unknown. Select the
5821 // minimum value between 'max_threads' and 'thread_limit' clauses that were
5822 // set.
5823 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5824 if (combinedMaxThreadsVal < 0 ||
5825 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5826 combinedMaxThreadsVal = teamsThreadLimitVal;
5827
5828 if (combinedMaxThreadsVal < 0 ||
5829 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5830 combinedMaxThreadsVal = maxThreadsVal;
5831
5832 int32_t reductionDataSize = 0;
5833 if (isGPU && capturedOp) {
5834 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5835 reductionDataSize = getReductionDataSize(teamsOp);
5836 }
5837
5838 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
5839 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5840 assert(
5841 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5842 omp::TargetRegionFlags::spmd) &&
5843 "invalid kernel flags");
5844 attrs.ExecFlags =
5845 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5846 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5847 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5848 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5849 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5850 if (omp::bitEnumContainsAll(kernelFlags,
5851 omp::TargetRegionFlags::spmd |
5852 omp::TargetRegionFlags::no_loop) &&
5853 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
5854 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
5855
5856 attrs.MinTeams = minTeamsVal;
5857 attrs.MaxTeams.front() = maxTeamsVal;
5858 attrs.MinThreads = 1;
5859 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5860 attrs.ReductionDataSize = reductionDataSize;
5861 // TODO: Allow modified buffer length similar to
5862 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
5863 if (attrs.ReductionDataSize != 0)
5864 attrs.ReductionBufferLength = 1024;
5865}
5866
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  // Collect the MLIR host values for every host-evaluated clause, including
  // loop bounds/steps (one slot per collapsed loop).
  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // Translate each clause value found above to its LLVM counterpart, leaving
  // the corresponding attribute untouched when the clause is absent.
  // TODO: Handle constant 'if' clauses.
  if (Value targetThreadLimit = targetOp.getThreadLimit())
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);

  if (numTeamsLower)
    attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);

  if (numTeamsUpper)
    attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() =
        moduleTranslation.lookupValue(teamsThreadLimit);

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop: initialize the accumulator with its trip count.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }
}
5935
5936static LogicalResult
5937convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
5938 LLVM::ModuleTranslation &moduleTranslation) {
5939 auto targetOp = cast<omp::TargetOp>(opInst);
5940 // The current debug location already has the DISubprogram for the outlined
5941 // function that will be created for the target op. We save it here so that
5942 // we can set it on the outlined function.
5943 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
5944 if (failed(checkImplementationStatus(opInst)))
5945 return failure();
5946
5947 // During the handling of target op, we will generate instructions in the
5948 // parent function like call to the oulined function or branch to a new
5949 // BasicBlock. We set the debug location here to parent function so that those
5950 // get the correct debug locations. For outlined functions, the normal MLIR op
5951 // conversion will automatically pick the correct location.
5952 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
5953 assert(parentBB && "No insert block is set for the builder");
5954 llvm::Function *parentLLVMFn = parentBB->getParent();
5955 assert(parentLLVMFn && "Parent Function must be valid");
5956 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
5957 builder.SetCurrentDebugLocation(llvm::DILocation::get(
5958 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
5959 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
5960
5961 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5962 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5963 bool isGPU = ompBuilder->Config.isGPU();
5964
5965 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
5966 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5967 auto &targetRegion = targetOp.getRegion();
5968 // Holds the private vars that have been mapped along with the block
5969 // argument that corresponds to the MapInfoOp corresponding to the private
5970 // var in question. So, for instance:
5971 //
5972 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
5973 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
5974 //
5975 // Then, %10 has been created so that the descriptor can be used by the
5976 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
5977 // %arg0} in the mappedPrivateVars map.
5978 llvm::DenseMap<Value, Value> mappedPrivateVars;
5979 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
5980 SmallVector<Value> mapVars = targetOp.getMapVars();
5981 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
5982 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
5983 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
5984 llvm::Function *llvmOutlinedFn = nullptr;
5985 TargetDirectiveEnumTy targetDirective =
5986 getTargetDirectiveEnumTyFromOp(&opInst);
5987
5988 // TODO: It can also be false if a compile-time constant `false` IF clause is
5989 // specified.
5990 bool isOffloadEntry =
5991 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5992
5993 // For some private variables, the MapsForPrivatizedVariablesPass
5994 // creates MapInfoOp instances. Go through the private variables and
5995 // the mapped variables so that during codegeneration we are able
5996 // to quickly look up the corresponding map variable, if any for each
5997 // private variable.
5998 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5999 OperandRange privateVars = targetOp.getPrivateVars();
6000 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6001 std::optional<DenseI64ArrayAttr> privateMapIndices =
6002 targetOp.getPrivateMapsAttr();
6003
6004 for (auto [privVarIdx, privVarSymPair] :
6005 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6006 auto privVar = std::get<0>(privVarSymPair);
6007 auto privSym = std::get<1>(privVarSymPair);
6008
6009 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6010 omp::PrivateClauseOp privatizer =
6011 findPrivatizer(targetOp, privatizerName);
6012
6013 if (!privatizer.needsMap())
6014 continue;
6015
6016 mlir::Value mappedValue =
6017 targetOp.getMappedValueForPrivateVar(privVarIdx);
6018 assert(mappedValue && "Expected to find mapped value for a privatized "
6019 "variable that needs mapping");
6020
6021 // The MapInfoOp defining the map var isn't really needed later.
6022 // So, we don't store it in any datastructure. Instead, we just
6023 // do some sanity checks on it right now.
6024 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
6025 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
6026
6027 // Check #1: Check that the type of the private variable matches
6028 // the type of the variable being mapped.
6029 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6030 assert(
6031 varType == privVar.getType() &&
6032 "Type of private var doesn't match the type of the mapped value");
6033
6034 // Ok, only 1 sanity check for now.
6035 // Record the block argument corresponding to this mapvar.
6036 mappedPrivateVars.insert(
6037 {privVar,
6038 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6039 (*privateMapIndices)[privVarIdx])});
6040 }
6041 }
6042
6043 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6044 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6045 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6046 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6047 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6048 // Forward target-cpu and target-features function attributes from the
6049 // original function to the new outlined function.
6050 llvm::Function *llvmParentFn =
6051 moduleTranslation.lookupFunction(parentFn.getName());
6052 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6053 assert(llvmParentFn && llvmOutlinedFn &&
6054 "Both parent and outlined functions must exist at this point");
6055
6056 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6057 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6058
6059 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
6060 attr.isStringAttribute())
6061 llvmOutlinedFn->addFnAttr(attr);
6062
6063 if (auto attr = llvmParentFn->getFnAttribute("target-features");
6064 attr.isStringAttribute())
6065 llvmOutlinedFn->addFnAttr(attr);
6066
6067 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6068 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6069 llvm::Value *mapOpValue =
6070 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6071 moduleTranslation.mapValue(arg, mapOpValue);
6072 }
6073 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6074 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6075 llvm::Value *mapOpValue =
6076 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6077 moduleTranslation.mapValue(arg, mapOpValue);
6078 }
6079
6080 // Do privatization after moduleTranslation has already recorded
6081 // mapped values.
6082 PrivateVarsInfo privateVarsInfo(targetOp);
6083
6085 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
6086 allocaIP, &mappedPrivateVars);
6087
6088 if (failed(handleError(afterAllocas, *targetOp)))
6089 return llvm::make_error<PreviouslyReportedError>();
6090
6091 builder.restoreIP(codeGenIP);
6092 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
6093 &mappedPrivateVars),
6094 *targetOp)
6095 .failed())
6096 return llvm::make_error<PreviouslyReportedError>();
6097
6098 if (failed(copyFirstPrivateVars(
6099 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
6100 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
6101 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6102 return llvm::make_error<PreviouslyReportedError>();
6103
6104 SmallVector<Region *> privateCleanupRegions;
6105 llvm::transform(privateVarsInfo.privatizers,
6106 std::back_inserter(privateCleanupRegions),
6107 [](omp::PrivateClauseOp privatizer) {
6108 return &privatizer.getDeallocRegion();
6109 });
6110
6112 targetRegion, "omp.target", builder, moduleTranslation);
6113
6114 if (!exitBlock)
6115 return exitBlock.takeError();
6116
6117 builder.SetInsertPoint(*exitBlock);
6118 if (!privateCleanupRegions.empty()) {
6119 if (failed(inlineOmpRegionCleanup(
6120 privateCleanupRegions, privateVarsInfo.llvmVars,
6121 moduleTranslation, builder, "omp.targetop.private.cleanup",
6122 /*shouldLoadCleanupRegionArg=*/false))) {
6123 return llvm::createStringError(
6124 "failed to inline `dealloc` region of `omp.private` "
6125 "op in the target region");
6126 }
6127 return builder.saveIP();
6128 }
6129
6130 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6131 };
6132
6133 StringRef parentName = parentFn.getName();
6134
6135 llvm::TargetRegionEntryInfo entryInfo;
6136
6137 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
6138
6139 MapInfoData mapData;
6140 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6141 builder, /*useDevPtrOperands=*/{},
6142 /*useDevAddrOperands=*/{}, hdaVars);
6143
6144 MapInfosTy combinedInfos;
6145 auto genMapInfoCB =
6146 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6147 builder.restoreIP(codeGenIP);
6148 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6149 targetDirective);
6150 return combinedInfos;
6151 };
6152
6153 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6154 llvm::Value *&retVal, InsertPointTy allocaIP,
6155 InsertPointTy codeGenIP)
6156 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6157 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6158 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6159 // We just return the unaltered argument for the host function
6160 // for now, some alterations may be required in the future to
6161 // keep host fallback functions working identically to the device
6162 // version (e.g. pass ByCopy values should be treated as such on
6163 // host and device, currently not always the case)
6164 if (!isTargetDevice) {
6165 retVal = cast<llvm::Value>(&arg);
6166 return codeGenIP;
6167 }
6168
6169 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
6170 *ompBuilder, moduleTranslation,
6171 allocaIP, codeGenIP);
6172 };
6173
6174 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6175 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6176 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6177 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
6178 isTargetDevice, isGPU);
6179
6180 // Collect host-evaluated values needed to properly launch the kernel from the
6181 // host.
6182 if (!isTargetDevice)
6183 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
6184 targetCapturedOp, runtimeAttrs);
6185
6186 // Pass host-evaluated values as parameters to the kernel / host fallback,
6187 // except if they are constants. In any case, map the MLIR block argument to
6188 // the corresponding LLVM values.
6190 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
6191 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
6192 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6193 llvm::Value *value = moduleTranslation.lookupValue(var);
6194 moduleTranslation.mapValue(arg, value);
6195
6196 if (!llvm::isa<llvm::Constant>(value))
6197 kernelInput.push_back(value);
6198 }
6199
6200 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6201 // declare target arguments are not passed to kernels as arguments
6202 // TODO: We currently do not handle cases where a member is explicitly
6203 // passed in as an argument, this will likley need to be handled in
6204 // the near future, rather than using IsAMember, it may be better to
6205 // test if the relevant BlockArg is used within the target region and
6206 // then use that as a basis for exclusion in the kernel inputs.
6207 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6208 kernelInput.push_back(mapData.OriginalValue[i]);
6209 }
6210
6212 buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
6213 moduleTranslation, dds);
6214
6215 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6216 findAllocaInsertPoint(builder, moduleTranslation);
6217 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6218
6219 llvm::OpenMPIRBuilder::TargetDataInfo info(
6220 /*RequiresDevicePointerInfo=*/false,
6221 /*SeparateBeginEndCalls=*/true);
6222
6223 auto customMapperCB =
6224 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6225 if (!combinedInfos.Mappers[i])
6226 return nullptr;
6227 info.HasMapper = true;
6228 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
6229 moduleTranslation, targetDirective);
6230 };
6231
6232 llvm::Value *ifCond = nullptr;
6233 if (Value targetIfCond = targetOp.getIfExpr())
6234 ifCond = moduleTranslation.lookupValue(targetIfCond);
6235
6236 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6237 moduleTranslation.getOpenMPBuilder()->createTarget(
6238 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6239 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6240 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6241
6242 if (failed(handleError(afterIP, opInst)))
6243 return failure();
6244
6245 builder.restoreIP(*afterIP);
6246
6247 // Remap access operations to declare target reference pointers for the
6248 // device, essentially generating extra loadop's as necessary
6249 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
6250 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
6251 llvmOutlinedFn);
6252
6253 return success();
6254}
6255
/// Handles the `omp.declare_target` attribute attached to either a function
/// or a global. For functions it may erase device-irrelevant wrappers; for
/// globals it registers the variable with the OpenMP offloading machinery.
6256static LogicalResult
6257convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
6258 LLVM::ModuleTranslation &moduleTranslation) {
6259 // Amend omp.declare_target by deleting the IR of the outlined functions
6260 // created for target regions. They cannot be filtered out from MLIR earlier
6261 // because the omp.target operation inside must be translated to LLVM, but
6262 // the wrapper functions themselves must not remain at the end of the
6263 // process. We know that functions where omp.declare_target does not match
6264 // omp.is_target_device at this stage can only be wrapper functions because
6265 // those that aren't are removed earlier as an MLIR transformation pass.
6266 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6267 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6268 op->getParentOfType<ModuleOp>().getOperation())) {
 // Host compilation keeps all functions; only device compilation filters.
6269 if (!offloadMod.getIsTargetDevice())
6270 return success();
6271
6272 omp::DeclareTargetDeviceType declareType =
6273 attribute.getDeviceType().getValue();
6274
 // Device compilation of a host-only function: delete its LLVM IR.
6275 if (declareType == omp::DeclareTargetDeviceType::host) {
6276 llvm::Function *llvmFunc =
6277 moduleTranslation.lookupFunction(funcOp.getName());
 // NOTE(review): assumes lookupFunction succeeds (function was already
 // translated); a null result would crash here — confirm ordering.
6278 llvmFunc->dropAllReferences();
6279 llvmFunc->eraseFromParent();
6280 }
6281 }
6282 return success();
6283 }
6284
 // Global variables: register with the OpenMP offload entry machinery.
6285 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6286 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6287 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6288 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6289 bool isDeclaration = gOp.isDeclaration();
6290 bool isExternallyVisible =
6291 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
 // File/line info is used below to build a unique offload entry name.
6292 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
6293 llvm::StringRef mangledName = gOp.getSymName();
6294 auto captureClause =
6295 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
6296 auto deviceClause =
6297 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
6298 // unused for MLIR at the moment, required in Clang for book
6299 // keeping
6300 std::vector<llvm::GlobalVariable *> generatedRefs;
6301
 // Target triple, if the enclosing module carries one.
6302 std::vector<llvm::Triple> targetTriple;
6303 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6304 op->getParentOfType<mlir::ModuleOp>()->getAttr(
6305 LLVM::LLVMDialect::getTargetTripleAttrName()));
6306 if (targetTripleAttr)
6307 targetTriple.emplace_back(targetTripleAttr.data());
6308
 // Produces (filename, line) for the unique-entry-info computation;
 // falls back to ("", 0) when no FileLineColLoc is available.
6309 auto fileInfoCallBack = [&loc]() {
6310 std::string filename = "";
6311 std::uint64_t lineNo = 0;
6312
6313 if (loc) {
6314 filename = loc.getFilename().str();
6315 lineNo = loc.getLine();
6316 }
6317
6318 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6319 lineNo);
6320 };
6321
6322 auto vfs = llvm::vfs::getRealFileSystem();
6323
6324 ompBuilder->registerTargetGlobalVariable(
6325 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6326 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6327 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6328 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
6329 gVal->getType(), gVal);
6330
 // On the device, non-`to` captures (or unified shared memory mode) also
 // need the declare-target reference pointer to be materialized.
6331 if (ompBuilder->Config.isTargetDevice() &&
6332 (attribute.getCaptureClause().getValue() !=
6333 mlir::omp::DeclareTargetCaptureClause::to ||
6334 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6335 ompBuilder->getAddrOfDeclareTargetVar(
6336 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6337 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6338 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6339 gVal->getType(), /*GlobalInitializer*/ nullptr,
6340 /*VariableLinkage*/ nullptr);
6341 }
6342 }
6343 }
6344
6345 return success();
6346}
6347
6348// Returns true if the operation is inside a TargetOp or
6349// is part of a declare target function.
6350static bool isTargetDeviceOp(Operation *op) {
6351 // Assumes no reverse offloading
6352 if (op->getParentOfType<omp::TargetOp>())
6353 return true;
6354
6355 // Certain operations return results, and whether utilised in host or
6356 // target there is a chance an LLVM Dialect operation depends on it
6357 // by taking it in as an operand, so we must always lower these in
6358 // some manner or result in an ICE (whether they end up in a no-op
6359 // or otherwise).
6360 if (mlir::isa<omp::ThreadprivateOp>(op))
6361 return true;
6362
6363 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
6364 mlir::isa<omp::TargetFreeMemOp>(op))
6365 return true;
6366
6367 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
6368 if (auto declareTargetIface =
6369 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6370 parentFn.getOperation()))
6371 if (declareTargetIface.isDeclareTarget() &&
6372 declareTargetIface.getDeclareTargetDeviceType() !=
6373 mlir::omp::DeclareTargetDeviceType::host)
6374 return true;
6375
6376 return false;
6377}
6378
6379static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
6380 llvm::Module *llvmModule) {
6381 llvm::Type *i64Ty = builder.getInt64Ty();
6382 llvm::Type *i32Ty = builder.getInt32Ty();
6383 llvm::Type *returnType = builder.getPtrTy(0);
6384 llvm::FunctionType *fnType =
6385 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
6386 llvm::Function *func = cast<llvm::Function>(
6387 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
6388 return func;
6389}
6390
6391static LogicalResult
6392convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6393 LLVM::ModuleTranslation &moduleTranslation) {
6394 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6395 if (!allocMemOp)
6396 return failure();
6397
6398 // Get "omp_target_alloc" function
6399 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6400 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
6401 // Get the corresponding device value in llvm
6402 mlir::Value deviceNum = allocMemOp.getDevice();
6403 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6404 // Get the allocation size.
6405 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6406 mlir::Type heapTy = allocMemOp.getAllocatedType();
6407 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6408 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6409 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6410 for (auto typeParam : allocMemOp.getTypeparams())
6411 allocSize =
6412 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6413 // Create call to "omp_target_alloc" with the args as translated llvm values.
6414 llvm::CallInst *call =
6415 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6416 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6417
6418 // Map the result
6419 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
6420 return success();
6421}
6422
6423static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
6424 llvm::Module *llvmModule) {
6425 llvm::Type *ptrTy = builder.getPtrTy(0);
6426 llvm::Type *i32Ty = builder.getInt32Ty();
6427 llvm::Type *voidTy = builder.getVoidTy();
6428 llvm::FunctionType *fnType =
6429 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
6430 llvm::Function *func = dyn_cast<llvm::Function>(
6431 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
6432 return func;
6433}
6434
6435static LogicalResult
6436convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6437 LLVM::ModuleTranslation &moduleTranslation) {
6438 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6439 if (!freeMemOp)
6440 return failure();
6441
6442 // Get "omp_target_free" function
6443 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6444 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
6445 // Get the corresponding device value in llvm
6446 mlir::Value deviceNum = freeMemOp.getDevice();
6447 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6448 // Get the corresponding heapref value in llvm
6449 mlir::Value heapref = freeMemOp.getHeapref();
6450 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
6451 // Convert heapref int to ptr and call "omp_target_free"
6452 llvm::Value *intToPtr =
6453 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
6454 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
6455 return success();
6456}
6457
6458/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
6459/// OpenMP runtime calls).
6460static LogicalResult
6461convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
6462 LLVM::ModuleTranslation &moduleTranslation) {
6463 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6464
6465 // For each loop, introduce one stack frame to hold loop information. Ensure
6466 // this is only done for the outermost loop wrapper to prevent introducing
6467 // multiple stack frames for a single loop. Initially set to null, the loop
6468 // information structure is initialized during translation of the nested
6469 // omp.loop_nest operation, making it available to translation of all loop
6470 // wrappers after their body has been successfully translated.
6471 bool isOutermostLoopWrapper =
6472 isa_and_present<omp::LoopWrapperInterface>(op) &&
6473 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
6474
6475 if (isOutermostLoopWrapper)
6476 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
6477
 // Dispatch on the concrete operation type; each case delegates to the
 // dedicated convert*/apply* routine for that construct.
6478 auto result =
6480 .Case([&](omp::BarrierOp op) -> LogicalResult {
6481 if (failed(checkImplementationStatus(*op)))
6482 return failure();
6483
6484 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6485 ompBuilder->createBarrier(builder.saveIP(),
6486 llvm::omp::OMPD_barrier);
6487 LogicalResult res = handleError(afterIP, *op);
6488 if (res.succeeded()) {
6489 // If the barrier generated a cancellation check, the insertion
6490 // point might now need to be changed to a new continuation block
6491 builder.restoreIP(*afterIP);
6492 }
6493 return res;
6494 })
6495 .Case([&](omp::TaskyieldOp op) {
6496 if (failed(checkImplementationStatus(*op)))
6497 return failure();
6498
6499 ompBuilder->createTaskyield(builder.saveIP());
6500 return success();
6501 })
6502 .Case([&](omp::FlushOp op) {
6503 if (failed(checkImplementationStatus(*op)))
6504 return failure();
6505
6506 // No support in Openmp runtime function (__kmpc_flush) to accept
6507 // the argument list.
6508 // OpenMP standard states the following:
6509 // "An implementation may implement a flush with a list by ignoring
6510 // the list, and treating it the same as a flush without a list."
6511 //
6512 // The argument list is discarded so that, flush with a list is
6513 // treated same as a flush without a list.
6514 ompBuilder->createFlush(builder.saveIP());
6515 return success();
6516 })
6517 .Case([&](omp::ParallelOp op) {
6518 return convertOmpParallel(op, builder, moduleTranslation);
6519 })
6520 .Case([&](omp::MaskedOp) {
6521 return convertOmpMasked(*op, builder, moduleTranslation);
6522 })
6523 .Case([&](omp::MasterOp) {
6524 return convertOmpMaster(*op, builder, moduleTranslation);
6525 })
6526 .Case([&](omp::CriticalOp) {
6527 return convertOmpCritical(*op, builder, moduleTranslation);
6528 })
6529 .Case([&](omp::OrderedRegionOp) {
6530 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
6531 })
6532 .Case([&](omp::OrderedOp) {
6533 return convertOmpOrdered(*op, builder, moduleTranslation);
6534 })
6535 .Case([&](omp::WsloopOp) {
6536 return convertOmpWsloop(*op, builder, moduleTranslation);
6537 })
6538 .Case([&](omp::SimdOp) {
6539 return convertOmpSimd(*op, builder, moduleTranslation);
6540 })
6541 .Case([&](omp::AtomicReadOp) {
6542 return convertOmpAtomicRead(*op, builder, moduleTranslation);
6543 })
6544 .Case([&](omp::AtomicWriteOp) {
6545 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
6546 })
6547 .Case([&](omp::AtomicUpdateOp op) {
6548 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
6549 })
6550 .Case([&](omp::AtomicCaptureOp op) {
6551 return convertOmpAtomicCapture(op, builder, moduleTranslation);
6552 })
6553 .Case([&](omp::CancelOp op) {
6554 return convertOmpCancel(op, builder, moduleTranslation);
6555 })
6556 .Case([&](omp::CancellationPointOp op) {
6557 return convertOmpCancellationPoint(op, builder, moduleTranslation);
6558 })
6559 .Case([&](omp::SectionsOp) {
6560 return convertOmpSections(*op, builder, moduleTranslation);
6561 })
6562 .Case([&](omp::SingleOp op) {
6563 return convertOmpSingle(op, builder, moduleTranslation);
6564 })
6565 .Case([&](omp::TeamsOp op) {
6566 return convertOmpTeams(op, builder, moduleTranslation);
6567 })
6568 .Case([&](omp::TaskOp op) {
6569 return convertOmpTaskOp(op, builder, moduleTranslation);
6570 })
6571 .Case([&](omp::TaskgroupOp op) {
6572 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
6573 })
6574 .Case([&](omp::TaskwaitOp op) {
6575 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
6576 })
6577 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6578 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6579 omp::CriticalDeclareOp>([](auto op) {
6580 // `yield` and `terminator` can be just omitted. The block structure
6581 // was created in the region that handles their parent operation.
6582 // `declare_reduction` will be used by reductions and is not
6583 // converted directly, skip it.
6584 // `declare_mapper` and `declare_mapper.info` are handled whenever
6585 // they are referred to through a `map` clause.
6586 // `critical.declare` is only used to declare names of critical
6587 // sections which will be used by `critical` ops and hence can be
6588 // ignored for lowering. The OpenMP IRBuilder will create unique
6589 // name for critical section names.
6590 return success();
6591 })
6592 .Case([&](omp::ThreadprivateOp) {
6593 return convertOmpThreadprivate(*op, builder, moduleTranslation);
6594 })
6595 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
6596 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
6597 return convertOmpTargetData(op, builder, moduleTranslation);
6598 })
6599 .Case([&](omp::TargetOp) {
6600 return convertOmpTarget(*op, builder, moduleTranslation);
6601 })
6602 .Case([&](omp::DistributeOp) {
6603 return convertOmpDistribute(*op, builder, moduleTranslation);
6604 })
6605 .Case([&](omp::LoopNestOp) {
6606 return convertOmpLoopNest(*op, builder, moduleTranslation);
6607 })
6608 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
6609 [&](auto op) {
6610 // No-op, should be handled by relevant owning operations e.g.
6611 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
6612 // etc. and then discarded
6613 return success();
6614 })
6615 .Case([&](omp::NewCliOp op) {
6616 // Meta-operation: Doesn't do anything by itself, but used to
6617 // identify a loop.
6618 return success();
6619 })
6620 .Case([&](omp::CanonicalLoopOp op) {
6621 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
6622 })
6623 .Case([&](omp::UnrollHeuristicOp op) {
6624 // FIXME: Handling omp.unroll_heuristic as an executable requires
6625 // that the generator (e.g. omp.canonical_loop) has been seen first.
6626 // For construct that require all codegen to occur inside a callback
6627 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
6628 // contained region including their transformations must occur at
6629 // the omp.canonical_loop.
6630 return applyUnrollHeuristic(op, builder, moduleTranslation);
6631 })
6632 .Case([&](omp::TileOp op) {
6633 return applyTile(op, builder, moduleTranslation);
6634 })
6635 .Case([&](omp::TargetAllocMemOp) {
6636 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
6637 })
6638 .Case([&](omp::TargetFreeMemOp) {
6639 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
6640 })
6641 .Default([&](Operation *inst) {
6642 return inst->emitError()
6643 << "not yet implemented: " << inst->getName();
6644 });
6645
 // Pop the loop info frame pushed above, regardless of success or failure.
6646 if (isOutermostLoopWrapper)
6647 moduleTranslation.stackPop();
6648
6649 return result;
6650}
6651
/// Translates an operation during device (target) compilation. Currently a
/// thin wrapper delegating to the shared host/target translation path.
6652static LogicalResult
6653convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
6654 LLVM::ModuleTranslation &moduleTranslation) {
6655 return convertHostOrTargetOperation(op, builder, moduleTranslation);
6656}
6657
/// During device compilation, translates only the target-related operations
/// nested inside \p op. The op itself (if it is omp.target / omp.target_data)
/// is translated fully; otherwise its nested target regions are translated
/// and everything else is walked as a non-OpenMP scope so that values the
/// target regions depend on still get LLVM definitions.
6658static LogicalResult
6659convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
6660 LLVM::ModuleTranslation &moduleTranslation) {
6661 if (isa<omp::TargetOp>(op))
6662 return convertOmpTarget(*op, builder, moduleTranslation);
6663 if (isa<omp::TargetDataOp>(op))
6664 return convertOmpTargetData(op, builder, moduleTranslation);
 // Pre-order walk: fully translate nested target/target_data ops, skipping
 // their interiors (skip), and abort the walk on any failure (interrupt).
6665 bool interrupted =
6666 op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
6667 if (isa<omp::TargetOp>(oper)) {
6668 if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
6669 return WalkResult::interrupt();
6670 return WalkResult::skip();
6671 }
6672 if (isa<omp::TargetDataOp>(oper)) {
6673 if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
6674 return WalkResult::interrupt();
6675 return WalkResult::skip();
6676 }
6677
6678 // Non-target ops might nest target-related ops, therefore, we
6679 // translate them as non-OpenMP scopes. Translating them is needed by
6680 // nested target-related ops since they might need LLVM values defined
6681 // in their parent non-target ops.
6682 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
6683 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
6684 !oper->getRegions().empty()) {
6685 if (auto blockArgsIface =
6686 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
6687 forwardArgs(moduleTranslation, blockArgsIface);
6688 else {
6689 // Here we map entry block arguments of
6690 // non-BlockArgOpenMPOpInterface ops if they can be encountered
6691 // inside of a function and they define any of these arguments.
6692 if (isa<mlir::omp::AtomicUpdateOp>(oper))
6693 for (auto [operand, arg] :
6694 llvm::zip_equal(oper->getOperands(),
6695 oper->getRegion(0).getArguments())) {
6696 moduleTranslation.mapValue(
6697 arg, builder.CreateLoad(
6698 moduleTranslation.convertType(arg.getType()),
6699 moduleTranslation.lookupValue(operand)));
6700 }
6701 }
6702
6703 if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
6704 assert(builder.GetInsertBlock() &&
6705 "No insert block is set for the builder");
6706 for (auto iv : loopNest.getIVs()) {
6707 // Map iv to an undefined value just to keep the IR validity.
6708 moduleTranslation.mapValue(
6709 iv, llvm::PoisonValue::get(
6710 moduleTranslation.convertType(iv.getType())));
6711 }
6712 }
6713
6714 for (Region &region : oper->getRegions()) {
6715 // Regions are fake in the sense that they are not a truthful
6716 // translation of the OpenMP construct being converted (e.g. no
6717 // OpenMP runtime calls will be generated). We just need this to
6718 // prepare the kernel invocation args.
6721 region, oper->getName().getStringRef().str() + ".fake.region",
6722 builder, moduleTranslation, &phis);
6723 if (failed(handleError(result, *oper)))
6724 return WalkResult::interrupt();
6725
 // Continue emitting after the (fake) region's exit block.
6726 builder.SetInsertPoint(result.get(), result.get()->end());
6727 }
6728
6729 return WalkResult::skip();
6730 }
6731
6732 return WalkResult::advance();
6733 }).wasInterrupted();
6734 return failure(interrupted);
6735}
6736
// Anonymous namespace: the interface implementation is local to this
// translation unit and registered via the dialect extension below.
6737namespace {
6738
6739/// Implementation of the dialect interface that converts operations belonging
6740/// to the OpenMP dialect to LLVM IR.
6741class OpenMPDialectLLVMIRTranslationInterface
6742 : public LLVMTranslationDialectInterface {
6743public:
6745
6746 /// Translates the given operation to LLVM IR using the provided IR builder
6747 /// and saving the state in `moduleTranslation`.
6748 LogicalResult
6749 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6750 LLVM::ModuleTranslation &moduleTranslation) const final;
6751
6752 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
6753 /// runtime calls, or operation amendments
6754 LogicalResult
6755 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6756 NamedAttribute attribute,
6757 LLVM::ModuleTranslation &moduleTranslation) const final;
6758};
6759
6760} // namespace
6761
6762LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6763 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6764 NamedAttribute attribute,
6765 LLVM::ModuleTranslation &moduleTranslation) const {
6766 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6767 attribute.getName())
6768 .Case("omp.is_target_device",
6769 [&](Attribute attr) {
6770 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6771 llvm::OpenMPIRBuilderConfig &config =
6772 moduleTranslation.getOpenMPBuilder()->Config;
6773 config.setIsTargetDevice(deviceAttr.getValue());
6774 return success();
6775 }
6776 return failure();
6777 })
6778 .Case("omp.is_gpu",
6779 [&](Attribute attr) {
6780 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6781 llvm::OpenMPIRBuilderConfig &config =
6782 moduleTranslation.getOpenMPBuilder()->Config;
6783 config.setIsGPU(gpuAttr.getValue());
6784 return success();
6785 }
6786 return failure();
6787 })
6788 .Case("omp.host_ir_filepath",
6789 [&](Attribute attr) {
6790 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6791 llvm::OpenMPIRBuilder *ompBuilder =
6792 moduleTranslation.getOpenMPBuilder();
6793 auto VFS = llvm::vfs::getRealFileSystem();
6794 ompBuilder->loadOffloadInfoMetadata(*VFS,
6795 filepathAttr.getValue());
6796 return success();
6797 }
6798 return failure();
6799 })
6800 .Case("omp.flags",
6801 [&](Attribute attr) {
6802 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6803 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6804 return failure();
6805 })
6806 .Case("omp.version",
6807 [&](Attribute attr) {
6808 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6809 llvm::OpenMPIRBuilder *ompBuilder =
6810 moduleTranslation.getOpenMPBuilder();
6811 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6812 versionAttr.getVersion());
6813 return success();
6814 }
6815 return failure();
6816 })
6817 .Case("omp.declare_target",
6818 [&](Attribute attr) {
6819 if (auto declareTargetAttr =
6820 dyn_cast<omp::DeclareTargetAttr>(attr))
6821 return convertDeclareTargetAttr(op, declareTargetAttr,
6822 moduleTranslation);
6823 return failure();
6824 })
6825 .Case("omp.requires",
6826 [&](Attribute attr) {
6827 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6828 using Requires = omp::ClauseRequires;
6829 Requires flags = requiresAttr.getValue();
6830 llvm::OpenMPIRBuilderConfig &config =
6831 moduleTranslation.getOpenMPBuilder()->Config;
6832 config.setHasRequiresReverseOffload(
6833 bitEnumContainsAll(flags, Requires::reverse_offload));
6834 config.setHasRequiresUnifiedAddress(
6835 bitEnumContainsAll(flags, Requires::unified_address));
6836 config.setHasRequiresUnifiedSharedMemory(
6837 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6838 config.setHasRequiresDynamicAllocators(
6839 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6840 return success();
6841 }
6842 return failure();
6843 })
6844 .Case("omp.target_triples",
6845 [&](Attribute attr) {
6846 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6847 llvm::OpenMPIRBuilderConfig &config =
6848 moduleTranslation.getOpenMPBuilder()->Config;
6849 config.TargetTriples.clear();
6850 config.TargetTriples.reserve(triplesAttr.size());
6851 for (Attribute tripleAttr : triplesAttr) {
6852 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6853 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6854 else
6855 return failure();
6856 }
6857 return success();
6858 }
6859 return failure();
6860 })
6861 .Default([](Attribute) {
6862 // Fall through for omp attributes that do not require lowering.
6863 return success();
6864 })(attribute.getValue());
6865
6866 return failure();
6867}
6868
6869/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
6870/// (including OpenMP runtime calls).
6871LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6872 Operation *op, llvm::IRBuilderBase &builder,
6873 LLVM::ModuleTranslation &moduleTranslation) const {
6874
6875 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6876 if (ompBuilder->Config.isTargetDevice()) {
6877 if (isTargetDeviceOp(op)) {
6878 return convertTargetDeviceOp(op, builder, moduleTranslation);
6879 }
6880 return convertTargetOpsInNest(op, builder, moduleTranslation);
6881 }
6882 return convertHostOrTargetOperation(op, builder, moduleTranslation);
6883}
6884
6886 registry.insert<omp::OpenMPDialect>();
6887 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
6888 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
6889 });
6890}
6891
6893 DialectRegistry registry;
6895 context.appendDialectRegistry(registry);
6896}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< mlir::Value > &mlirPrivateVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
iterator begin()
Definition Block.h:143
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1293
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars