MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  // Insertion point where allocas for the enclosing construct should be
  // created; consumed by findAllocaInsertPoint during nested translation.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
81
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Remains null until the loop nest inside the wrapper has been translated.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
91
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
/// Its purpose is to serve as the glue between MLIR failures represented as
/// \see LogicalResult instances and \see llvm::Error instances used to
/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
/// error of the first type is raised, a message is emitted directly (the \see
/// LogicalResult itself does not hold any information). If we need to forward
/// this error condition as an \see llvm::Error while avoiding triggering some
/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
/// class to just signal this situation has happened.
///
/// For example, this class should be used to trigger errors from within
/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
/// translation of their own regions. This unclutters the error log from
/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  void log(raw_ostream &) const override {
    // Do not log anything: the diagnostic was already emitted when this error
    // condition was first detected.
  }

  std::error_code convertToErrorCode() const override {
    // This error type is only meant to be matched by typed error handlers,
    // never converted into a std::error_code.
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

char PreviouslyReportedError::ID = 0;
125
/*
 * Custom class for processing the linear clause for omp.wsloop
 * and omp.simd. Linear clause translation requires setup,
 * initialization, update, and finalization at varying
 * basic blocks in the IR. This class helps maintain
 * internal state to allow consistent translation in
 * each of these stages.
 */

class LinearClauseProcessor {

private:
  // Allocas capturing each linear variable's value on entry to the loop.
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas holding the per-iteration updated value of each linear variable.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // Original storage locations of the linear variables.
  SmallVector<llvm::Value *> linearOrigVal;
  // Step values, parallel to the vectors above.
  SmallVector<llvm::Value *> linearSteps;
  // LLVM types of the linear variables, parallel to the vectors above.
  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks produced by splitLinearFiniBB, used for conditional finalization.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register type for the linear variables
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space for linear variables at the builder's current insertion
  // point. Assumes registerType was already called for index `idx`.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Initialize linear step
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: in the loop pre-header,
  // load each original variable and save its starting value.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR for updating linear variables in the loop body:
  // result = start + iv * step, stored into the per-iteration temp.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "
                         "type");

      if (linearVarType->isIntegerTy()) {
        // Integer path: normalize all arithmetic to linearVarType
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        // Float path: perform multiply in integer, then convert to float
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else {
        llvm_unreachable(
            "Linear variable must be of integer or floating-point type");
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same.
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: store back to the original variables on the
  // last logical iteration only, then emit a barrier. Must be called after
  // splitLinearFiniBB. Returns the builder's position after the barrier.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    // Drop the unconditional branch left behind by splitBasicBlock.
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Emit stores for linear variables. Useful in case of SIMD
  // construct.
  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  // Rewrite all uses of the original variable in `BBName`
  // with the linear variable in-place. Matching is by substring on the basic
  // block name, so callers must pass a sufficiently unique name.
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    // Snapshot the user list first: replaceUsesOfWith mutates the use list.
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
297
298} // namespace
299
/// Looks up, starting from operation \p from, and returns the PrivateClauseOp
/// with name \p symbolName. The symbol is expected to always resolve; a miss
/// indicates malformed IR.
static omp::PrivateClauseOp findPrivatizer(Operation *from,
                                           SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      symbolName);
  assert(privatizer && "privatizer not found in the symbol table");
  return privatizer;
}
310
/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// \returns success if no unimplemented features are needed to translate the
/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits the standard "not yet implemented" diagnostic for a clause name.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // One checker per clause; each sets `result` to failure when the operation
  // uses a clause this translation does not support yet.
  auto checkAffinity = [&todo](auto op, LogicalResult &result) {
    if (!op.getAffinityVars().empty())
      result = todo("affinity");
  };
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  // `hint` is only discarded with a warning, not treated as an error.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  };
  // Reduction is only unsupported on teams/taskloop; reduction modifiers other
  // than the default are unsupported everywhere this checker is used.
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumTeamsMultiDim())
      result = todo("num_teams with multi-dimensional values");
  };
  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumThreadsMultiDim())
      result = todo("num_threads with multi-dimensional values");
  };

  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
    if (op.hasThreadLimitMultiDim())
      result = todo("thread_limit with multi-dimensional values");
  };

  // Dispatch on the concrete operation type; each case runs only the checks
  // relevant to the clauses that operation can carry.
  LogicalResult result = success();
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkNumTeams(op, result);
        checkThreadLimit(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAffinity(op, result);
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
        checkNumThreads(op, result);
      })
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkInReduction(op, result);
        checkThreadLimit(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
456
457static LogicalResult handleError(llvm::Error error, Operation &op) {
458 LogicalResult result = success();
459 if (error) {
460 llvm::handleAllErrors(
461 std::move(error),
462 [&](const PreviouslyReportedError &) { result = failure(); },
463 [&](const llvm::ErrorInfoBase &err) {
464 result = op.emitError(err.message());
465 });
466 }
467 return result;
468}
469
470template <typename T>
471static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
472 if (!result)
473 return handleError(result.takeError(), op);
474
475 return success();
476}
477
478/// Find the insertion point for allocas given the current insertion point for
479/// normal operations in the builder.
480static llvm::OpenMPIRBuilder::InsertPointTy
481findAllocaInsertPoint(llvm::IRBuilderBase &builder,
482 LLVM::ModuleTranslation &moduleTranslation) {
483 // If there is an alloca insertion point on stack, i.e. we are in a nested
484 // operation and a specific point was provided by some surrounding operation,
485 // use it.
486 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
487 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
488 [&](OpenMPAllocaStackFrame &frame) {
489 allocaInsertPoint = frame.allocaInsertPoint;
490 return WalkResult::interrupt();
491 });
492 // In cases with multiple levels of outlining, the tree walk might find an
493 // alloca insertion point that is inside the original function while the
494 // builder insertion point is inside the outlined function. We need to make
495 // sure that we do not use it in those cases.
496 if (walkResult.wasInterrupted() &&
497 allocaInsertPoint.getBlock()->getParent() ==
498 builder.GetInsertBlock()->getParent())
499 return allocaInsertPoint;
500
501 // Otherwise, insert to the entry block of the surrounding function.
502 // If the current IRBuilder InsertPoint is the function's entry, it cannot
503 // also be used for alloca insertion which would result in insertion order
504 // confusion. Create a new BasicBlock for the Builder and use the entry block
505 // for the allocs.
506 // TODO: Create a dedicated alloca BasicBlock at function creation such that
507 // we do not need to move the current InertPoint here.
508 if (builder.GetInsertBlock() ==
509 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
510 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
511 "Assuming end of basic block");
512 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
513 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
514 builder.GetInsertBlock()->getNextNode());
515 builder.CreateBr(entryBB);
516 builder.SetInsertPoint(entryBB);
517 }
518
519 llvm::BasicBlock &funcEntryBlock =
520 builder.GetInsertBlock()->getParent()->getEntryBlock();
521 return llvm::OpenMPIRBuilder::InsertPointTy(
522 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
523}
524
525/// Find the loop information structure for the loop nest being translated. It
526/// will return a `null` value unless called from the translation function for
527/// a loop wrapper operation after successfully translating its body.
528static llvm::CanonicalLoopInfo *
530 llvm::CanonicalLoopInfo *loopInfo = nullptr;
531 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
532 [&](OpenMPLoopInfoStackFrame &frame) {
533 loopInfo = frame.loopInfo;
534 return WalkResult::interrupt();
535 });
536 return loopInfo;
537}
538
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with a successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  // Split at the insertion point so everything after it becomes the
  // continuation block.
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  // Create an (initially empty) LLVM block for each MLIR block and record the
  // mapping so branches between them can be resolved later.
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // First yield encountered: record the operand types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Every subsequent yield must agree in arity and types.
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      // The block's own translation already emitted a diagnostic.
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
666
667/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
668static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
669 switch (kind) {
670 case omp::ClauseProcBindKind::Close:
671 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
672 case omp::ClauseProcBindKind::Master:
673 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
674 case omp::ClauseProcBindKind::Primary:
675 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
676 case omp::ClauseProcBindKind::Spread:
677 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
678 }
679 llvm_unreachable("Unknown ClauseProcBindKind kind");
680}
681
682/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
683static LogicalResult
684convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
685 LLVM::ModuleTranslation &moduleTranslation) {
686 auto maskedOp = cast<omp::MaskedOp>(opInst);
687 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
688
689 if (failed(checkImplementationStatus(opInst)))
690 return failure();
691
692 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
693 // MaskedOp has only one region associated with it.
694 auto &region = maskedOp.getRegion();
695 builder.restoreIP(codeGenIP);
696 return convertOmpOpRegions(region, "omp.masked.region", builder,
697 moduleTranslation)
698 .takeError();
699 };
700
701 // TODO: Perform finalization actions for variables. This has to be
702 // called for variables which have destructors/finalizers.
703 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
704
705 llvm::Value *filterVal = nullptr;
706 if (auto filterVar = maskedOp.getFilteredThreadId()) {
707 filterVal = moduleTranslation.lookupValue(filterVar);
708 } else {
709 llvm::LLVMContext &llvmContext = builder.getContext();
710 filterVal =
711 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
712 }
713 assert(filterVal != nullptr);
714 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
715 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
716 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
717 finiCB, filterVal);
718
719 if (failed(handleError(afterIP, opInst)))
720 return failure();
721
722 builder.restoreIP(*afterIP);
723 return success();
724}
725
726/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
727static LogicalResult
728convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
729 LLVM::ModuleTranslation &moduleTranslation) {
730 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
731 auto masterOp = cast<omp::MasterOp>(opInst);
732
733 if (failed(checkImplementationStatus(opInst)))
734 return failure();
735
736 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
737 // MasterOp has only one region associated with it.
738 auto &region = masterOp.getRegion();
739 builder.restoreIP(codeGenIP);
740 return convertOmpOpRegions(region, "omp.master.region", builder,
741 moduleTranslation)
742 .takeError();
743 };
744
745 // TODO: Perform finalization actions for variables. This has to be
746 // called for variables which have destructors/finalizers.
747 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
748
749 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
750 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
751 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
752 finiCB);
753
754 if (failed(handleError(afterIP, opInst)))
755 return failure();
756
757 builder.restoreIP(*afterIP);
758 return success();
759}
760
/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too.
  if (criticalOp.getNameAttr()) {
    // The verifiers in the OpenMP dialect guarantee that all the pointers are
    // non-null.
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        symbolRef);
    // Pass the hint through as an i32 constant for the runtime.
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
810
811/// A util to collect info needed to convert delayed privatizers from MLIR to
812/// LLVM.
// NOTE(review): this listing appears to have dropped the struct header
// (source line 813, presumably `struct PrivateVarsInfo {`) and line 815 (the
// constructor signature) — verify against the original file.
814 template <typename OP>
816 : blockArgs(
817 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
// One MLIR/LLVM value slot is expected per private block argument.
818 mlirVars.reserve(blockArgs.size());
819 llvmVars.reserve(blockArgs.size());
820 collectPrivatizationDecls<OP>(op);
821
822 for (mlir::Value privateVar : op.getPrivateVars())
823 mlirVars.push_back(privateVar);
824 }
825
// NOTE(review): the member declarations (source lines 826-829) are missing
// from this listing; the constructor and collectPrivatizationDecls refer to
// members named `blockArgs`, `mlirVars`, `llvmVars` and `privatizers`.
830
831private:
832 /// Populates `privatizations` with privatization declarations used for the
833 /// given op.
834 template <class OP>
835 void collectPrivatizationDecls(OP op) {
836 std::optional<ArrayAttr> attr = op.getPrivateSyms();
// An op without a private-syms attribute has no privatizers to collect.
837 if (!attr)
838 return;
839
840 privatizers.reserve(privatizers.size() + attr->size());
841 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
842 privatizers.push_back(findPrivatizer(op, symbolRef));
843 }
844 }
845};
846
847/// Populates `reductions` with reduction declarations used in the given op.
848template <typename T>
849static void
// NOTE(review): this listing appears to have dropped the function signature
// (source lines 850-851, the `collectReductionDecls(...)` opener) — verify
// against the original file.
852 std::optional<ArrayAttr> attr = op.getReductionSyms();
// No reduction-syms attribute means there is nothing to collect.
853 if (!attr)
854 return;
855
856 reductions.reserve(reductions.size() + op.getNumReductionVars());
857 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
858 reductions.push_back(
// NOTE(review): source line 859 (the symbol lookup that resolves `symbolRef`
// to a reduction declaration) is missing from this listing.
860 op, symbolRef));
861 }
862}
863
864/// Translates the blocks contained in the given region and appends them to at
865/// the current insertion point of `builder`. The operations of the entry block
866/// are appended to the current insertion block. If set, `continuationBlockArgs`
867/// is populated with translated values that correspond to the values
868/// omp.yield'ed from the region.
869static LogicalResult inlineConvertOmpRegions(
870 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
871 LLVM::ModuleTranslation &moduleTranslation,
872 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
// An empty region translates to nothing; report success.
873 if (region.empty())
874 return success();
875
876 // Special case for single-block regions that don't create additional blocks:
877 // insert operations without creating additional blocks.
878 if (region.hasOneBlock()) {
// Detach the current block's terminator (if any) so the region's operations
// are emitted before it; it is re-attached after conversion below.
879 llvm::Instruction *potentialTerminator =
880 builder.GetInsertBlock()->empty() ? nullptr
881 : &builder.GetInsertBlock()->back();
882
883 if (potentialTerminator && potentialTerminator->isTerminator())
884 potentialTerminator->removeFromParent();
885 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
886
887 if (failed(moduleTranslation.convertBlock(
888 region.front(), /*ignoreArguments=*/true, builder)))
889 return failure();
890
891 // The continuation arguments are simply the translated terminator operands.
892 if (continuationBlockArgs)
893 llvm::append_range(
894 *continuationBlockArgs,
895 moduleTranslation.lookupValues(region.front().back().getOperands()));
896
897 // Drop the mapping that is no longer necessary so that the same region can
898 // be processed multiple times.
899 moduleTranslation.forgetMapping(region);
900
// Re-attach the terminator that was detached above.
901 if (potentialTerminator && potentialTerminator->isTerminator()) {
902 llvm::BasicBlock *block = builder.GetInsertBlock();
903 if (block->empty()) {
904 // this can happen for really simple reduction init regions e.g.
905 // %0 = llvm.mlir.constant(0 : i32) : i32
906 // omp.yield(%0 : i32)
907 // because the llvm.mlir.constant (MLIR op) isn't converted into any
908 // llvm op
909 potentialTerminator->insertInto(block, block->begin());
910 } else {
911 potentialTerminator->insertAfter(&block->back());
912 }
913 }
914
915 return success();
916 }
917
// NOTE(review): source line 918 (presumably `SmallVector<llvm::Value *>
// phis;`) is missing from this listing.
919 llvm::Expected<llvm::BasicBlock *> continuationBlock =
920 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
921
922 if (failed(handleError(continuationBlock, *region.getParentOp())))
923 return failure();
924
925 if (continuationBlockArgs)
926 llvm::append_range(*continuationBlockArgs, phis);
// Position the builder at the start of the continuation block so subsequent
// code is emitted after the inlined region.
927 builder.SetInsertPoint(*continuationBlock,
928 (*continuationBlock)->getFirstInsertionPt());
929 return success();
930}
931
932namespace {
933/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
934/// store lambdas with capture.
// Non-atomic combiner: receives the insert point plus LHS/RHS values and
// returns the combined value through the `llvm::Value *&` out-parameter.
935using OwningReductionGen =
936 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
937 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
938 llvm::Value *&)>;
// Atomic combiner: same shape but takes an element type and produces no
// result out-parameter (see makeAtomicReductionGen).
939using OwningAtomicReductionGen =
940 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
941 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
942 llvm::Value *)>;
// By-ref helper: maps a by-ref reduction value to the pointer holding its
// data (see makeRefDataPtrGen).
943using OwningDataPtrPtrReductionGen =
944 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
945 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
946} // namespace
947
948/// Create an OpenMPIRBuilder-compatible reduction generator for the given
949/// reduction declaration. The generator uses `builder` but ignores its
950/// insertion point.
951static OwningReductionGen
952makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
953 LLVM::ModuleTranslation &moduleTranslation) {
954 // The lambda is mutable because we need access to non-const methods of decl
955 // (which aren't actually mutating it), and we must capture decl by-value to
956 // avoid the dangling reference after the parent function returns.
957 OwningReductionGen gen =
958 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
959 llvm::Value *lhs, llvm::Value *rhs,
960 llvm::Value *&result) mutable
961 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the combiner region's block arguments to the incoming LHS/RHS values.
962 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
963 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
964 builder.restoreIP(insertPoint);
// NOTE(review): source line 965 (presumably the declaration of `phis`) is
// missing from this listing.
966 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
967 "omp.reduction.nonatomic.body", builder,
968 moduleTranslation, &phis)))
969 return llvm::createStringError(
970 "failed to inline `combiner` region of `omp.declare_reduction`");
// The combiner region yields exactly one value: the combined result.
971 result = llvm::getSingleElement(phis);
972 return builder.saveIP();
973 };
974 return gen;
975}
976
977/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
978/// given reduction declaration. The generator uses `builder` but ignores its
979/// insertion point. Returns null if there is no atomic region available in the
980/// reduction declaration.
981static OwningAtomicReductionGen
982makeAtomicReductionGen(omp::DeclareReductionOp decl,
983 llvm::IRBuilderBase &builder,
984 LLVM::ModuleTranslation &moduleTranslation) {
// A default-constructed (empty) std::function signals "no atomic region".
985 if (decl.getAtomicReductionRegion().empty())
986 return OwningAtomicReductionGen();
987
988 // The lambda is mutable because we need access to non-const methods of decl
989 // (which aren't actually mutating it), and we must capture decl by-value to
990 // avoid the dangling reference after the parent function returns.
991 OwningAtomicReductionGen atomicGen =
992 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
993 llvm::Value *lhs, llvm::Value *rhs) mutable
994 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
995 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
996 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
997 builder.restoreIP(insertPoint);
// NOTE(review): source line 998 (presumably the declaration of `phis`) is
// missing from this listing.
999 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1000 "omp.reduction.atomic.body", builder,
1001 moduleTranslation, &phis)))
1002 return llvm::createStringError(
1003 "failed to inline `atomic` region of `omp.declare_reduction`");
// The atomic region is expected to yield no values (asserted below).
1004 assert(phis.empty());
1005 return builder.saveIP();
1006 };
1007 return atomicGen;
1008}
1009
1010/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1011/// the given reduction declaration. The generator uses `builder` but ignores
1012/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1013/// available in the reduction declaration.
1014static OwningDataPtrPtrReductionGen
1015makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1016 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
// Only by-ref reductions need this generator; return an empty std::function
// otherwise.
1017 if (!isByRef)
1018 return OwningDataPtrPtrReductionGen();
1019
// decl is captured by value to avoid dangling once this function returns;
// mutable because non-const methods of decl are called inside.
1020 OwningDataPtrPtrReductionGen refDataPtrGen =
1021 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1022 llvm::Value *byRefVal, llvm::Value *&result) mutable
1023 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1024 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1025 builder.restoreIP(insertPoint);
// NOTE(review): source line 1026 (presumably the declaration of `phis`) is
// missing from this listing.
1027 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1028 "omp.data_ptr_ptr.body", builder,
1029 moduleTranslation, &phis)))
1030 return llvm::createStringError(
1031 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
// The region yields exactly one value: the data pointer for `byRefVal`.
1032 result = llvm::getSingleElement(phis);
1033 return builder.saveIP();
1034 };
1035
1036 return refDataPtrGen;
1037}
1038
1039/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1040static LogicalResult
1041convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1042 LLVM::ModuleTranslation &moduleTranslation) {
1043 auto orderedOp = cast<omp::OrderedOp>(opInst);
1044
1045 if (failed(checkImplementationStatus(opInst)))
1046 return failure();
1047
1048 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1049 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1050 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1051 SmallVector<llvm::Value *> vecValues =
1052 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1053
1054 size_t indexVecValues = 0;
1055 while (indexVecValues < vecValues.size()) {
1056 SmallVector<llvm::Value *> storeValues;
1057 storeValues.reserve(numLoops);
1058 for (unsigned i = 0; i < numLoops; i++) {
1059 storeValues.push_back(vecValues[indexVecValues]);
1060 indexVecValues++;
1061 }
1062 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1063 findAllocaInsertPoint(builder, moduleTranslation);
1064 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1065 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1066 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1067 }
1068 return success();
1069}
1070
1071/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1072/// OpenMPIRBuilder.
1073static LogicalResult
1074convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1075 LLVM::ModuleTranslation &moduleTranslation) {
1076 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1077 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1078
1079 if (failed(checkImplementationStatus(opInst)))
1080 return failure();
1081
1082 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1083 // OrderedOp has only one region associated with it.
1084 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1085 builder.restoreIP(codeGenIP);
1086 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1087 moduleTranslation)
1088 .takeError();
1089 };
1090
1091 // TODO: Perform finalization actions for variables. This has to be
1092 // called for variables which have destructors/finalizers.
1093 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1094
1095 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1096 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1097 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1098 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1099
1100 if (failed(handleError(afterIP, opInst)))
1101 return failure();
1102
1103 builder.restoreIP(*afterIP);
1104 return success();
1105}
1106
1107namespace {
1108/// Contains the arguments for an LLVM store operation
1109struct DeferredStore {
1110 DeferredStore(llvm::Value *value, llvm::Value *address)
1111 : value(value), address(address) {}
1112
1113 llvm::Value *value;
1114 llvm::Value *address;
1115};
1116} // namespace
1117
1118/// Allocate space for privatized reduction variables.
1119/// `deferredStores` contains information to create store operations which needs
1120/// to be inserted after all allocas
1121template <typename T>
1122static LogicalResult
// NOTE(review): source line 1123 (the signature opener carrying the function
// name and leading parameters) is missing from this listing.
1124 llvm::IRBuilderBase &builder,
1125 LLVM::ModuleTranslation &moduleTranslation,
1126 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// NOTE(review): source line 1127 (presumably the `reductionDecls` parameter)
// is missing from this listing.
1128 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1129 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1130 SmallVectorImpl<DeferredStore> &deferredStores,
1131 llvm::ArrayRef<bool> isByRefs) {
// All allocas go to the dedicated alloca insertion point; the guard restores
// the caller's insertion point on scope exit.
1132 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1133 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1134
1135 // delay creating stores until after all allocas
1136 deferredStores.reserve(loop.getNumReductionVars());
1137
1138 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1139 Region &allocRegion = reductionDecls[i].getAllocRegion();
1140 if (isByRefs[i]) {
// By-ref reductions without an alloc region are handled in initReductionVars.
1141 if (allocRegion.empty())
1142 continue;
1143
// NOTE(review): source line 1144 (presumably the declaration of `phis`) is
// missing from this listing.
1145 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1146 builder, moduleTranslation, &phis)))
1147 return loop.emitError(
1148 "failed to inline `alloc` region of `omp.declare_reduction`");
1149
1150 assert(phis.size() == 1 && "expected one allocation to be yielded");
// Inlining may have moved the insertion point; return to the alloca block
// for the pointer alloca below.
1151 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1152
1153 // Allocate reduction variable (which is a pointer to the real reduction
1154 // variable allocated in the inlined region)
1155 llvm::Value *var = builder.CreateAlloca(
1156 moduleTranslation.convertType(reductionDecls[i].getType()));
1157
// Cast to the default pointer type; the helper also handles differing
// address spaces.
1158 llvm::Type *ptrTy = builder.getPtrTy();
1159 llvm::Value *castVar =
1160 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1161 llvm::Value *castPhi =
1162 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1163
// The store of the yielded allocation into the pointer alloca must happen
// after *all* allocas; record it for later.
1164 deferredStores.emplace_back(castPhi, castVar);
1165
1166 privateReductionVariables[i] = castVar;
1167 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1168 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1169 } else {
1170 assert(allocRegion.empty() &&
1171 "allocaction is implicit for by-val reduction");
1172 llvm::Value *var = builder.CreateAlloca(
1173 moduleTranslation.convertType(reductionDecls[i].getType()));
1174
1175 llvm::Type *ptrTy = builder.getPtrTy();
1176 llvm::Value *castVar =
1177 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1178
1179 moduleTranslation.mapValue(reductionArgs[i], castVar);
1180 privateReductionVariables[i] = castVar;
1181 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1182 }
1183 }
1184
1185 return success();
1186}
1187
1188/// Map input arguments to reduction initialization region
1189template <typename T>
1190static void
// NOTE(review): source line 1191 (the signature opener carrying the function
// name and leading parameters) is missing from this listing.
1192 llvm::IRBuilderBase &builder,
// NOTE(review): source line 1193 (presumably the `reductionDecls` parameter)
// is missing from this listing.
1194 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1195 unsigned i) {
1196 // map input argument to the initialization region
1197 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1198 Region &initializerRegion = reduction.getInitializerRegion();
1199 Block &entry = initializerRegion.front();
1200
1201 mlir::Value mlirSource = loop.getReductionVars()[i];
1202 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1203 llvm::Value *origVal = llvmSource;
1204 // If a non-pointer value is expected, load the value from the source pointer.
1205 if (!isa<LLVM::LLVMPointerType>(
1206 reduction.getInitializerMoldArg().getType()) &&
1207 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1208 origVal =
1209 builder.CreateLoad(moduleTranslation.convertType(
1210 reduction.getInitializerMoldArg().getType()),
1211 llvmSource, "omp_orig")
1212 }
1213 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1214
// A second entry-block argument, when present, is bound to the previously
// recorded allocation for this reduction variable.
1215 if (entry.getNumArguments() > 1) {
1216 llvm::Value *allocation =
1217 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1218 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1219 }
1220}
1221
1222static void
1223setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1224 llvm::BasicBlock *block = nullptr) {
1225 if (block == nullptr)
1226 block = builder.GetInsertBlock();
1227
1228 if (block->empty() || block->getTerminator() == nullptr)
1229 builder.SetInsertPoint(block);
1230 else
1231 builder.SetInsertPoint(block->getTerminator());
1232}
1233
1234/// Inline reductions' `init` regions. This functions assumes that the
1235/// `builder`'s insertion point is where the user wants the `init` regions to be
1236/// inlined; i.e. it does not try to find a proper insertion location for the
1237/// `init` regions. It also leaves the `builder's insertions point in a state
1238/// where the user can continue the code-gen directly afterwards.
1239template <typename OP>
1240static LogicalResult
1241initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1242 llvm::IRBuilderBase &builder,
1243 LLVM::ModuleTranslation &moduleTranslation,
1244 llvm::BasicBlock *latestAllocaBlock,
// NOTE(review): source line 1245 (presumably the `reductionDecls` parameter)
// is missing from this listing.
1246 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1247 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1248 llvm::ArrayRef<bool> isByRef,
1249 SmallVectorImpl<DeferredStore> &deferredStores) {
1250 if (op.getNumReductionVars() == 0)
1251 return success();
1252
// Split off a dedicated block for reduction initialization, then move the
// builder into the alloca block for the allocas emitted below.
1253 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1254 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1255 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1256 builder.restoreIP(allocaIP);
1257 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1258
1259 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1260 if (isByRef[i]) {
1261 if (!reductionDecls[i].getAllocRegion().empty())
1262 continue;
1263
1264 // TODO: remove after all users of by-ref are updated to use the alloc
1265 // region: Allocate reduction variable (which is a pointer to the real
1266 // reduciton variable allocated in the inlined region)
1267 byRefVars[i] = builder.CreateAlloca(
1268 moduleTranslation.convertType(reductionDecls[i].getType()));
1269 }
1270 }
1271
1272 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1273
1274 // store result of the alloc region to the allocated pointer to the real
1275 // reduction variable
1276 for (auto [data, addr] : deferredStores)
1277 builder.CreateStore(data, addr);
1278
1279 // Before the loop, store the initial values of reductions into reduction
1280 // variables. Although this could be done after allocas, we don't want to mess
1281 // up with the alloca insertion point.
1282 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
// NOTE(review): source line 1283 (presumably the declaration of `phis`) is
// missing from this listing.
1284
1285 // map block argument to initializer region
1286 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1287 reductionVariableMap, i);
1288
1289 // TODO In some cases (specially on the GPU), the init regions may
1290 // contains stack alloctaions. If the region is inlined in a loop, this is
1291 // problematic. Instead of just inlining the region, handle allocations by
1292 // hoisting fixed length allocations to the function entry and using
1293 // stacksave and restore for variable length ones.
1294 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1295 "omp.reduction.neutral", builder,
1296 moduleTranslation, &phis)))
1297 return failure();
1298
1299 assert(phis.size() == 1 && "expected one value to be yielded from the "
1300 "reduction neutral element declaration region");
// NOTE(review): source lines 1301-1302 are missing from this listing
// (presumably a blank line plus a builder repositioning call).
1303
1304 if (isByRef[i]) {
1305 if (!reductionDecls[i].getAllocRegion().empty())
1306 // done in allocReductionVars
1307 continue;
1308
1309 // TODO: this path can be removed once all users of by-ref are updated to
1310 // use an alloc region
1311
1312 // Store the result of the inlined region to the allocated reduction var
1313 // ptr
1314 builder.CreateStore(phis[0], byRefVars[i]);
1315
1316 privateReductionVariables[i] = byRefVars[i];
1317 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1318 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1319 } else {
1320 // for by-ref case the store is inside of the reduction region
1321 builder.CreateStore(phis[0], privateReductionVariables[i]);
1322 // the rest was handled in allocByValReductionVars
1323 }
1324
1325 // forget the mapping for the initializer region because we might need a
1326 // different mapping if this reduction declaration is re-used for a
1327 // different variable
1328 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1329 }
1330
1331 return success();
1332}
1333
1334/// Collect reduction info
1335template <typename T>
1336static void collectReductionInfo(
1337 T loop, llvm::IRBuilderBase &builder,
1338 LLVM::ModuleTranslation &moduleTranslation,
// NOTE(review): source lines 1339-1340 (presumably the `reductionDecls` and
// `owningReductionGens` parameters) are missing from this listing.
1341 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
// NOTE(review): source line 1342 (presumably the
// `owningReductionGenRefDataPtrGens` parameter) is missing from this listing.
1343 const ArrayRef<llvm::Value *> privateReductionVariables,
// NOTE(review): source line 1344 (presumably the `reductionInfos` parameter)
// is missing from this listing.
1345 ArrayRef<bool> isByRef) {
1346 unsigned numReductions = loop.getNumReductionVars();
1347
// Build one owning generator of each kind per reduction; they must outlive
// the ReductionInfo entries that reference them.
1348 for (unsigned i = 0; i < numReductions; ++i) {
1349 owningReductionGens.push_back(
1350 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1351 owningAtomicReductionGens.push_back(
1352 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
// NOTE(review): source line 1353 (presumably the push_back of
// `makeRefDataPtrGen(...)`) is missing from this listing.
1354 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1355 }
1356
1357 // Collect the reduction information.
1358 reductionInfos.reserve(numReductions);
1359 for (unsigned i = 0; i < numReductions; ++i) {
// Only pass an atomic callback when the declaration actually has one.
1360 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1361 if (owningAtomicReductionGens[i])
1362 atomicGen = owningAtomicReductionGens[i];
1363 llvm::Value *variable =
1364 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
// Walk the alloc region to discover the element type of the allocation, if
// any, so it can be forwarded to the ReductionInfo below.
1365 mlir::Type allocatedType;
1366 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1367 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1368 allocatedType = alloca.getElemType();
// NOTE(review): source line 1369 is missing from this listing (presumably the
// walk-interrupt return).
1370 }
1371
// NOTE(review): source line 1372 is missing from this listing (presumably the
// walk-advance return).
1373 });
1374
1375 reductionInfos.push_back(
1376 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1377 privateReductionVariables[i],
1378 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
// NOTE(review): source line 1379 is missing from this listing (presumably the
// non-atomic ReductionGen entry).
1380 /*ReductionGenClang=*/nullptr, atomicGen,
// NOTE(review): source line 1381 is missing from this listing.
1382 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1383 reductionDecls[i].getByrefElementType()
1384 ? moduleTranslation.convertType(
1385 *reductionDecls[i].getByrefElementType())
1386 : nullptr});
1387 }
1388}
1389
1390/// handling of DeclareReductionOp's cleanup region
1391static LogicalResult
// NOTE(review): source line 1392 (the signature opener carrying the function
// name `inlineOmpRegionCleanup` and its first parameter) is missing from this
// listing.
1393 llvm::ArrayRef<llvm::Value *> privateVariables,
1394 LLVM::ModuleTranslation &moduleTranslation,
1395 llvm::IRBuilderBase &builder, StringRef regionName,
1396 bool shouldLoadCleanupRegionArg = true) {
1397 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
// Declarations without a cleanup region need no work.
1398 if (cleanupRegion->empty())
1399 continue;
1400
1401 // map the argument to the cleanup region
1402 Block &entry = cleanupRegion->front();
1403
// If the current block already ends with a terminator, emit the cleanup code
// before it.
1404 llvm::Instruction *potentialTerminator =
1405 builder.GetInsertBlock()->empty() ? nullptr
1406 : &builder.GetInsertBlock()->back();
1407 if (potentialTerminator && potentialTerminator->isTerminator())
1408 builder.SetInsertPoint(potentialTerminator);
// Depending on the caller, the cleanup region receives either the loaded
// private value or the pointer itself.
1409 llvm::Value *privateVarValue =
1410 shouldLoadCleanupRegionArg
1411 ? builder.CreateLoad(
1412 moduleTranslation.convertType(entry.getArgument(0).getType()),
1413 privateVariables[i])
1414 : privateVariables[i];
1415
1416 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1417
1418 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1419 moduleTranslation)))
1420 return failure();
1421
1422 // clear block argument mapping in case it needs to be re-created with a
1423 // different source for another use of the same reduction decl
1424 moduleTranslation.forgetMapping(*cleanupRegion);
1425 }
1426 return success();
1427}
1428
1429// TODO: not used by ParallelOp
1430template <class OP>
1431static LogicalResult createReductionsAndCleanup(
1432 OP op, llvm::IRBuilderBase &builder,
1433 LLVM::ModuleTranslation &moduleTranslation,
1434 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// NOTE(review): source line 1435 (presumably the `reductionDecls` parameter)
// is missing from this listing.
1436 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1437 bool isNowait = false, bool isTeamsReduction = false) {
1438 // Process the reductions if required.
1439 if (op.getNumReductionVars() == 0)
1440 return success();
1441
// NOTE(review): source line 1442 (presumably the declaration of
// `owningReductionGens`) is missing from this listing.
1443 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1444 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
// NOTE(review): source line 1445 (presumably the declaration of
// `reductionInfos`) is missing from this listing.
1446
1447 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1448
1449 // Create the reduction generators. We need to own them here because
1450 // ReductionInfo only accepts references to the generators.
1451 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1452 owningReductionGens, owningAtomicReductionGens,
1453 owningReductionGenRefDataPtrGens,
1454 privateReductionVariables, reductionInfos, isByRef);
1455
1456 // The call to createReductions below expects the block to have a
1457 // terminator. Create an unreachable instruction to serve as terminator
1458 // and remove it later.
1459 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1460 builder.SetInsertPoint(tempTerminator);
1461 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1462 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1463 isByRef, isNowait, isTeamsReduction);
1464
1465 if (failed(handleError(contInsertPoint, *op)))
1466 return failure();
1467
1468 if (!contInsertPoint->getBlock())
1469 return op->emitOpError() << "failed to convert reductions";
1470
// Non-teams reductions are followed by a barrier before resuming codegen.
1471 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1472 if (!isTeamsReduction) {
1473 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1474 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1475
1476 if (failed(handleError(barrierIP, *op)))
1477 return failure();
1478 afterIP = *barrierIP;
1479 }
1480
1481 tempTerminator->eraseFromParent();
1482 builder.restoreIP(afterIP);
1483
1484 // after the construct, deallocate private reduction variables
1485 SmallVector<Region *> reductionRegions;
1486 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1487 [](omp::DeclareReductionOp reductionDecl) {
1488 return &reductionDecl.getCleanupRegion();
1489 });
1490 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1491 moduleTranslation, builder,
1492 "omp.reduction.cleanup");
// NOTE(review): the `return success();` below is unreachable dead code — the
// function already returns on the statement above. Consider removing it.
1493 return success();
1494}
1495
1496static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1497 if (!attr)
1498 return {};
1499 return *attr;
1500}
1501
1502// TODO: not used by omp.parallel
1503template <typename OP>
// NOTE(review): source line 1504 (the signature opener, presumably
// `static LogicalResult allocAndInitializeReductionVars(`) is missing from
// this listing.
1505 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1506 LLVM::ModuleTranslation &moduleTranslation,
1507 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
// NOTE(review): source line 1508 (presumably the `reductionDecls` parameter)
// is missing from this listing.
1509 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1510 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1511 llvm::ArrayRef<bool> isByRef) {
1512 if (op.getNumReductionVars() == 0)
1513 return success();
1514
1515 SmallVector<DeferredStore> deferredStores;
1516
// First allocate storage (recording any deferred stores), then run the
// `init` regions over the fresh allocations.
1517 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1518 allocaIP, reductionDecls,
1519 privateReductionVariables, reductionVariableMap,
1520 deferredStores, isByRef)))
1521 return failure();
1522
1523 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1524 allocaIP.getBlock(), reductionDecls,
1525 privateReductionVariables, reductionVariableMap,
1526 isByRef, deferredStores);
1527}
1528
1529/// Return the llvm::Value * corresponding to the `privateVar` that
1530/// is being privatized. It isn't always as simple as looking up
1531/// moduleTranslation with privateVar. For instance, in case of
1532/// an allocatable, the descriptor for the allocatable is privatized.
1533/// This descriptor is mapped using an MapInfoOp. So, this function
1534/// will return a pointer to the llvm::Value corresponding to the
1535/// block argument for the mapped descriptor.
1536static llvm::Value *
1537findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1538 LLVM::ModuleTranslation &moduleTranslation,
1539 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1540 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1541 return moduleTranslation.lookupValue(privateVar);
1542
1543 Value blockArg = (*mappedPrivateVars)[privateVar];
1544 Type privVarType = privateVar.getType();
1545 Type blockArgType = blockArg.getType();
1546 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1547 "A block argument corresponding to a mapped var should have "
1548 "!llvm.ptr type");
1549
1550 if (privVarType == blockArgType)
1551 return moduleTranslation.lookupValue(blockArg);
1552
1553 // This typically happens when the privatized type is lowered from
1554 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1555 // struct/pair is passed by value. But, mapped values are passed only as
1556 // pointers, so before we privatize, we must load the pointer.
1557 if (!isa<LLVM::LLVMPointerType>(privVarType))
1558 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1559 moduleTranslation.lookupValue(blockArg));
1560
1561 return moduleTranslation.lookupValue(privateVar);
1562}
1563
1564/// Initialize a single (first)private variable. You probably want to use
1565/// allocateAndInitPrivateVars instead of this.
1566/// This returns the private variable which has been initialized. This
1567/// variable should be mapped before constructing the body of the Op.
// NOTE(review): source line 1568 (the return type of this function,
// presumably `static llvm::Expected<llvm::Value *>`) is missing from this
// listing.
1569initPrivateVar(llvm::IRBuilderBase &builder,
1570 LLVM::ModuleTranslation &moduleTranslation,
1571 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1572 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1573 llvm::BasicBlock *privInitBlock,
1574 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1575 Region &initRegion = privDecl.getInitRegion();
// Without an `init` region the freshly allocated variable is used as-is.
1576 if (initRegion.empty())
1577 return llvmPrivateVar;
1578
1579 assert(nonPrivateVar);
// Bind the init region's mold/private arguments before inlining it.
1580 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1581 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1582
1583 // in-place convert the private initialization region
// NOTE(review): source line 1584 (presumably the declaration of `phis`) is
// missing from this listing.
1585 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1586 moduleTranslation, &phis)))
1587 return llvm::createStringError(
1588 "failed to inline `init` region of `omp.private`");
1589
1590 assert(phis.size() == 1 && "expected one allocation to be yielded");
1591
1592 // clear init region block argument mapping in case it needs to be
1593 // re-created with a different source for another use of the same
1594 // reduction decl
1595 moduleTranslation.forgetMapping(initRegion);
1596
1597 // Prefer the value yielded from the init region to the allocated private
1598 // variable in case the region is operating on arguments by-value (e.g.
1599 // Fortran character boxes).
1600 return phis[0];
1601}
1602
1603/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
// NOTE(review): source line 1604 (the signature opener, presumably
// `static llvm::Expected<llvm::Value *> initPrivateVar(`) is missing from
// this listing.
1605 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1606 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1607 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1608 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// Resolve the original (non-private) value — honoring any explicit mapping —
// then delegate to the main initPrivateVar overload.
1609 return initPrivateVar(
1610 builder, moduleTranslation, privDecl,
1611 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1612 mappedPrivateVars),
1613 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1614}
1615
1616static llvm::Error
1617initPrivateVars(llvm::IRBuilderBase &builder,
1618 LLVM::ModuleTranslation &moduleTranslation,
1619 PrivateVarsInfo &privateVarsInfo,
1620 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// Nothing to initialize when the op has no private block arguments.
1621 if (privateVarsInfo.blockArgs.empty())
1622 return llvm::Error::success();
1623
// All `init` regions are inlined into a dedicated block split off here.
1624 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1625 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1626
1627 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1628 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1629 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1630 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
// NOTE(review): source line 1631 (presumably `llvm::Expected<llvm::Value *>
// privVarOrErr = initPrivateVar(`) is missing from this listing.
1632 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1633 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1634
1635 if (!privVarOrErr)
1636 return privVarOrErr.takeError();
1637
// The init region may yield a value different from the allocation; record it
// (via the zip reference) and remap the block argument to it.
1638 llvmPrivateVar = privVarOrErr.get();
1639 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1640
// NOTE(review): source line 1641 is missing from this listing (presumably a
// builder repositioning call such as setInsertPointForPossiblyEmptyBlock).
1642 }
1643
1644 return llvm::Error::success();
1645}
1646
/// Allocate and initialize delayed private variables. Returns the basic block
/// which comes after all of these allocations. llvm::Value * for each of these
/// private variables are populated in llvmPrivateVars.
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    PrivateVarsInfo &privateVarsInfo,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  // Split the alloca block so that code following the allocations lands in
  // "omp.region.after_alloca".
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator since the alloca block was split above.
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  // The new terminator is an unconditional branch created by the splitBB above.
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  // NOTE(review): dataLayout appears unused in this function — confirm.
  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS = moduleTranslation.getLLVMModule()
                               ->getDataLayout()
                               .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs)) {
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
    // If allocas live in a different address space than the program default,
    // cast so downstream code always sees a default-AS pointer.
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  }

  return afterAllocas;
}
1697
/// This can't always be determined statically, but when we can, it is good to
/// avoid generating compiler-added barriers which will deadlock the program.
/// Returns true when \p op is known to execute single-threaded because the
/// nearest relevant OpenMP ancestor is an omp.single or omp.critical region.
  for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
      return true;

    // A parallel region between op and the single/critical spawns multiple
    // threads again, e.g.
    // omp.single {
    //   omp.parallel {
    //     op
    //   }
    // }
    if (mlir::isa<omp::ParallelOp>(parent))
      return false;
  }
  // No enclosing single-threaded construct found: assume multi-threaded.
  return false;
}
1717
/// Inline the `copy` region of each firstprivate privatizer so that the
/// original values are copied into the private allocations. Optionally emits
/// a barrier afterwards (unless \p op is known to run single-threaded).
static LogicalResult copyFirstPrivateVars(
    mlir::Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    ArrayRef<llvm::Value *> llvmPrivateVars,
    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Apply copy region for firstprivate.
  // Skip all of this work when no privatizer is firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);

  for (auto [decl, moldVar, llvmVar] :
       llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg (the original, non-private value)
    moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);

    // map copyRegion lhs arg (the private copy)
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");


    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  // Emit a barrier unless this op is known to run on a single thread.
  if (insertBarrier && !opIsInSingleThread(op)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(res, *op)))
      return failure();
  }

  return success();
}
1777
1778static LogicalResult copyFirstPrivateVars(
1779 mlir::Operation *op, llvm::IRBuilderBase &builder,
1780 LLVM::ModuleTranslation &moduleTranslation,
1781 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1782 ArrayRef<llvm::Value *> llvmPrivateVars,
1783 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1784 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1785 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1786 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1787 // map copyRegion rhs arg
1788 llvm::Value *moldVar = findAssociatedValue(
1789 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1790 assert(moldVar);
1791 return moldVar;
1792 });
1793 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1794 llvmPrivateVars, privateDecls, insertBarrier,
1795 mappedPrivateVars);
1796}
1797
/// Inline the `dealloc` region of each privatizer in order to clean up the
/// private variables in \p llvmPrivateVars. Emits a diagnostic at \p loc when
/// inlining fails.
static LogicalResult
cleanupPrivateVars(llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation, Location loc,
                   SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
  // private variable deallocation
  SmallVector<Region *> privateCleanupRegions;
  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();
                  });

  if (failed(inlineOmpRegionCleanup(
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");

  return success();
}
1818
/// Returns true if the construct contains omp.cancel or omp.cancellation_point
/// anywhere within \p op (checked by walking the operation).
  // omp.cancel and omp.cancellation_point must be "closely nested" so they will
  // be visible and not inside of function calls. This is enforced by the
  // verifier.
  return op
      ->walk([](Operation *child) {
        if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}
1832
/// Converts an OpenMP sections construct into LLVM IR using OpenMPIRBuilder,
/// including reduction setup and one body-generation callback per omp.section.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

      sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
      reductionDecls, privateReductionVariables, reductionVariableMap,
      isByRef)))
    return failure();

  // Build one body-generation callback per omp.section child.
  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  bool isCancellable = constructIsCancellable(sectionsOp);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
}
1938
/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
/// Handles the `copyprivate` clause by collecting the LLVM values and copy
/// functions to pass to createSingle.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*singleOp)))
    return failure();

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
          llvmCPFuncs);

  if (failed(handleError(afterIP, *singleOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
1981
/// Returns true when every non-debug use of the teams reduction block
/// arguments is contained in one single omp.distribute op, in which case the
/// reduction is handled by the distribute construct rather than by teams.
/// As a side effect (when returning true), erases debug uses of those block
/// arguments so they are not left unmapped during translation.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
  auto iface =
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  // Check that all uses of the reduction block arg has the same distribute op
  // parent.
  Operation *distOp = nullptr;
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();
      // Ignore debug uses.
      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);
        continue;
      }

      auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
      // Use is not inside a distribute op - return false
      if (!currentDistOp)
        return false;
      // Multiple distribute operations - return false
      Operation *currentOp = currentDistOp.getOperation();
      if (distOp && (distOp != currentOp))
        return false;

      distOp = currentOp;
    }

  // If we are going to use distribute reduction then remove any debug uses of
  // the reduction parameters in teamsOp. Otherwise they will be left without
  // any mapped value in moduleTranslation and will eventually error out.
  for (auto *use : debugUses)
    use->erase();
  return true;
}
2017
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*op)))
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = op.getNumReductionVars();
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Only do teams reduction if there is no distribute op that captures the
  // reduction instead.
  bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

    collectReductionDecls(op, reductionDecls);

        op, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the clause operands (when present) to LLVM values.
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (!op.getNumTeamsUpperVars().empty())
    numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));

  llvm::Value *threadLimit = nullptr;
  if (!op.getThreadLimitVars().empty())
    threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
    // Process the reductions if required.
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
2097
/// Populate \p dds with one OpenMPIRBuilder::DependData entry per
/// (variable, kind) pair from the `depend` clause. No-op when \p dependVars
/// is empty.
static void
buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
                LLVM::ModuleTranslation &moduleTranslation,
  if (dependVars.empty())
    return;
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
    switch (
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
      break;
    // The OpenMP runtime requires that the codegen for 'depend' clause for
    // 'out' dependency kind must be the same as codegen for 'depend' clause
    // with 'inout' dependency.
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
      break;
    };
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
  }
}
2130
/// Shared implementation of a callback which adds a terminator for the new
/// block created for the branch taken when an openmp construct is cancelled.
/// The terminator is saved in \p cancelTerminators. This callback is invoked
/// only if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct.
static void
                         llvm::IRBuilderBase &llvmBuilder,
                         llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
                         llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    // ip is currently in the block branched to if cancellation occurred.
    // We need to create a branch to terminate that block.
    llvmBuilder.restoreIP(ip);

    // We must still clean up the construct after cancelling it, so we need to
    // branch to the block that finalizes the taskgroup.
    // That block has not been created yet so use this block as a dummy for now
    // and fix this after creating the operation.
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  };
  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
  // created in case the body contains omp.cancel (which will then expect to be
  // able to find this cleanup callback).
  ompBuilder.pushFinalizationCB(
      {finiCB, cancelDirective, constructIsCancellable(op)});
}
2162
/// If we cancelled the construct, we should branch to the finalization block of
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
/// is immediately before the continuation block. Now this finalization has
/// been created we can fix the branch.
static void
                        llvm::OpenMPIRBuilder &ompBuilder,
                        const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  // The cleanup block is the single predecessor of the continuation block.
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    // Retarget the placeholder self-branch at the construct's cleanup block.
    cancelBranch->setSuccessor(0, constructFini);
  }
}
2179
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs into the context structure, index-aligned with privateDecls.
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Pointer to the heap-allocated context structure (null if none was made).
  llvm::Value *getStructPtr() { return structPtr; }

private:
  /// IR builder used for all code generated by this manager.
  llvm::IRBuilderBase &builder;
  /// Translation state, used to convert MLIR types to LLVM types.
  LLVM::ModuleTranslation &moduleTranslation;
  /// The privatizers for the task, in declaration order.
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2237
2238void TaskContextStructManager::generateTaskContextStruct() {
2239 if (privateDecls.empty())
2240 return;
2241 privateVarTypes.reserve(privateDecls.size());
2242
2243 for (omp::PrivateClauseOp &privOp : privateDecls) {
2244 // Skip private variables which can safely be allocated and initialised
2245 // inside of the task
2246 if (!privOp.readsFromMold())
2247 continue;
2248 Type mlirType = privOp.getType();
2249 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2250 }
2251
2252 if (privateVarTypes.empty())
2253 return;
2254
2255 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2256 privateVarTypes);
2257
2258 llvm::DataLayout dataLayout =
2259 builder.GetInsertBlock()->getModule()->getDataLayout();
2260 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2261 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2262
2263 // Heap allocate the structure
2264 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2265 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2266 "omp.task.context_ptr");
2267}
2268
2269SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2270 llvm::Value *altStructPtr) const {
2271 SmallVector<llvm::Value *> ret;
2272
2273 // Create GEPs for each struct member
2274 ret.reserve(privateDecls.size());
2275 llvm::Value *zero = builder.getInt32(0);
2276 unsigned i = 0;
2277 for (auto privDecl : privateDecls) {
2278 if (!privDecl.readsFromMold()) {
2279 // Handle this inside of the task so we don't pass unnessecary vars in
2280 ret.push_back(nullptr);
2281 continue;
2282 }
2283 llvm::Value *iVal = builder.getInt32(i);
2284 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2285 ret.push_back(gep);
2286 i += 1;
2287 }
2288 return ret;
2289}
2290
2291void TaskContextStructManager::createGEPsToPrivateVars() {
2292 if (!structPtr)
2293 assert(privateVarTypes.empty());
2294 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2295 // with null values for skipped private decls
2296
2297 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2298}
2299
2300void TaskContextStructManager::freeStructPtr() {
2301 if (!structPtr)
2302 return;
2303
2304 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2305 // Ensure we don't put the call to free() after the terminator
2306 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2307 builder.CreateFree(structPtr);
2308}
2309
2310/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2311static LogicalResult
2312convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2313 LLVM::ModuleTranslation &moduleTranslation) {
2314 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2315 if (failed(checkImplementationStatus(*taskOp)))
2316 return failure();
2317
2318 PrivateVarsInfo privateVarsInfo(taskOp);
2319 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2320 privateVarsInfo.privatizers};
2321
2322 // Allocate and copy private variables before creating the task. This avoids
2323 // accessing invalid memory if (after this scope ends) the private variables
2324 // are initialized from host variables or if the variables are copied into
2325 // from host variables (firstprivate). The insertion point is just before
2326 // where the code for creating and scheduling the task will go. That puts this
2327 // code outside of the outlined task region, which is what we want because
2328 // this way the initialization and copy regions are executed immediately while
2329 // the host variable data are still live.
2330
2331 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2332 findAllocaInsertPoint(builder, moduleTranslation);
2333
2334 // Not using splitBB() because that requires the current block to have a
2335 // terminator.
2336 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2337 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2338 builder.getContext(), "omp.task.start",
2339 /*Parent=*/builder.GetInsertBlock()->getParent());
2340 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2341 builder.SetInsertPoint(branchToTaskStartBlock);
2342
2343 // Now do this again to make the initialization and copy blocks
2344 llvm::BasicBlock *copyBlock =
2345 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2346 llvm::BasicBlock *initBlock =
2347 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2348
2349 // Now the control flow graph should look like
2350 // starter_block:
2351 // <---- where we started when convertOmpTaskOp was called
2352 // br %omp.private.init
2353 // omp.private.init:
2354 // br %omp.private.copy
2355 // omp.private.copy:
2356 // br %omp.task.start
2357 // omp.task.start:
2358 // <---- where we want the insertion point to be when we call createTask()
2359
2360 // Save the alloca insertion point on ModuleTranslation stack for use in
2361 // nested regions.
2363 moduleTranslation, allocaIP);
2364
2365 // Allocate and initialize private variables
2366 builder.SetInsertPoint(initBlock->getTerminator());
2367
2368 // Create task variable structure
2369 taskStructMgr.generateTaskContextStruct();
2370 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2371 // of the body otherwise it will be the GEP not the struct which is fowarded
2372 // to the outlined function. GEPs forwarded in this way are passed in a
2373 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2374 // which may not be executed until after the current stack frame goes out of
2375 // scope.
2376 taskStructMgr.createGEPsToPrivateVars();
2377
2378 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2379 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2380 privateVarsInfo.blockArgs,
2381 taskStructMgr.getLLVMPrivateVarGEPs())) {
2382 // To be handled inside the task.
2383 if (!privDecl.readsFromMold())
2384 continue;
2385 assert(llvmPrivateVarAlloc &&
2386 "reads from mold so shouldn't have been skipped");
2387
2388 llvm::Expected<llvm::Value *> privateVarOrErr =
2389 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2390 blockArg, llvmPrivateVarAlloc, initBlock);
2391 if (!privateVarOrErr)
2392 return handleError(privateVarOrErr, *taskOp.getOperation());
2393
2395
2396 // TODO: this is a bit of a hack for Fortran character boxes.
2397 // Character boxes are passed by value into the init region and then the
2398 // initialized character box is yielded by value. Here we need to store the
2399 // yielded value into the private allocation, and load the private
2400 // allocation to match the type expected by region block arguments.
2401 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2402 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2403 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2404 // Load it so we have the value pointed to by the GEP
2405 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2406 llvmPrivateVarAlloc);
2407 }
2408 assert(llvmPrivateVarAlloc->getType() ==
2409 moduleTranslation.convertType(blockArg.getType()));
2410
2411 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2412 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2413 // stack allocated structure.
2414 }
2415
2416 // firstprivate copy region
2417 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2418 if (failed(copyFirstPrivateVars(
2419 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2420 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2421 taskOp.getPrivateNeedsBarrier())))
2422 return llvm::failure();
2423
2424 // Set up for call to createTask()
2425 builder.SetInsertPoint(taskStartBlock);
2426
2427 auto bodyCB = [&](InsertPointTy allocaIP,
2428 InsertPointTy codegenIP) -> llvm::Error {
2429 // Save the alloca insertion point on ModuleTranslation stack for use in
2430 // nested regions.
2432 moduleTranslation, allocaIP);
2433
2434 // translate the body of the task:
2435 builder.restoreIP(codegenIP);
2436
2437 llvm::BasicBlock *privInitBlock = nullptr;
2438 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2439 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2440 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2441 privateVarsInfo.mlirVars))) {
2442 auto [blockArg, privDecl, mlirPrivVar] = zip;
2443 // This is handled before the task executes
2444 if (privDecl.readsFromMold())
2445 continue;
2446
2447 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2448 llvm::Type *llvmAllocType =
2449 moduleTranslation.convertType(privDecl.getType());
2450 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2451 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2452 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2453
2454 llvm::Expected<llvm::Value *> privateVarOrError =
2455 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2456 blockArg, llvmPrivateVar, privInitBlock);
2457 if (!privateVarOrError)
2458 return privateVarOrError.takeError();
2459 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2460 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2461 }
2462
2463 taskStructMgr.createGEPsToPrivateVars();
2464 for (auto [i, llvmPrivVar] :
2465 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2466 if (!llvmPrivVar) {
2467 assert(privateVarsInfo.llvmVars[i] &&
2468 "This is added in the loop above");
2469 continue;
2470 }
2471 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2472 }
2473
2474 // Find and map the addresses of each variable within the task context
2475 // structure
2476 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2477 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2478 privateVarsInfo.privatizers)) {
2479 // This was handled above.
2480 if (!privateDecl.readsFromMold())
2481 continue;
2482 // Fix broken pass-by-value case for Fortran character boxes
2483 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2484 llvmPrivateVar = builder.CreateLoad(
2485 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2486 }
2487 assert(llvmPrivateVar->getType() ==
2488 moduleTranslation.convertType(blockArg.getType()));
2489 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2490 }
2491
2492 auto continuationBlockOrError = convertOmpOpRegions(
2493 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2494 if (failed(handleError(continuationBlockOrError, *taskOp)))
2495 return llvm::make_error<PreviouslyReportedError>();
2496
2497 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2498
2499 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2500 privateVarsInfo.llvmVars,
2501 privateVarsInfo.privatizers)))
2502 return llvm::make_error<PreviouslyReportedError>();
2503
2504 // Free heap allocated task context structure at the end of the task.
2505 taskStructMgr.freeStructPtr();
2506
2507 return llvm::Error::success();
2508 };
2509
2510 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2511 SmallVector<llvm::BranchInst *> cancelTerminators;
2512 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2513 // which is canceled. This is handled here because it is the task's cleanup
2514 // block which should be branched to.
2515 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2516 llvm::omp::Directive::OMPD_taskgroup);
2517
2519 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2520 moduleTranslation, dds);
2521
2522 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2523 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2524 moduleTranslation.getOpenMPBuilder()->createTask(
2525 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2526 moduleTranslation.lookupValue(taskOp.getFinal()),
2527 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2528 taskOp.getMergeable(),
2529 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2530 moduleTranslation.lookupValue(taskOp.getPriority()));
2531
2532 if (failed(handleError(afterIP, *taskOp)))
2533 return failure();
2534
2535 // Set the correct branch target for task cancellation
2536 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2537
2538 builder.restoreIP(*afterIP);
2539 return success();
2540}
2541
2542 /// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
// Strategy (parallels convertOmpTaskOp): privatizers that read from their mold
// ("firstprivate"-like) are allocated and initialized *before* the taskloop is
// created, inside a heap-allocated context structure managed by
// TaskContextStructManager, so the values captured at creation time are what
// the tasks later observe. Privatizers that do not read from the mold are
// allocated and initialized inside the task body callback instead. Because the
// runtime splits a taskloop into tasks by duplicating the original task, a
// task-duplication callback (taskDupCB) clones the context structure too.
//
// NOTE(review): this listing appears to have dropped a few declaration lines
// (listing lines 2574, 2633, 2724, 2741, 2743, 2762) — e.g. the SaveStack
// frame declarations, taskDupCB's trailing return type, and the definitions of
// `srcGEPs`/`destGEPs` referenced below. Verify against the upstream source
// before editing this function.
2543 static LogicalResult
2544 convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
2545 LLVM::ModuleTranslation &moduleTranslation) {
2546 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2547 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
// Bail out early if the op carries clauses that are not implemented yet.
2548 if (failed(checkImplementationStatus(opInst)))
2549 return failure();
2550
2551 // It stores the pointer of allocated firstprivate copies,
2552 // which can be used later for freeing the allocated space.
2553 SmallVector<llvm::Value *> llvmFirstPrivateVars;
2554 PrivateVarsInfo privateVarsInfo(taskloopOp);
2555 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2556 privateVarsInfo.privatizers};
2557
2558 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2559 findAllocaInsertPoint(builder, moduleTranslation);
2560
// Carve dedicated basic blocks out of the current insertion point: the start
// block (where createTaskloop is eventually emitted), plus blocks for the
// firstprivate copy region and for private-variable initialization, so that
// all of that happens before the taskloop is created.
2561 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2562 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2563 builder.getContext(), "omp.taskloop.start",
2564 /*Parent=*/builder.GetInsertBlock()->getParent());
2565 llvm::Instruction *branchToTaskloopStartBlock =
2566 builder.CreateBr(taskloopStartBlock);
2567 builder.SetInsertPoint(branchToTaskloopStartBlock);
2568
2569 llvm::BasicBlock *copyBlock =
2570 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2571 llvm::BasicBlock *initBlock =
2572 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2573
// NOTE(review): dropped line 2574 here — presumably the SaveStack frame
// declaration that these arguments belong to.
2575 moduleTranslation, allocaIP);
2576
2577 // Allocate and initialize private variables
2578 builder.SetInsertPoint(initBlock->getTerminator());
2579
2580 // TODO: don't allocate if the loop has zero iterations.
2581 taskStructMgr.generateTaskContextStruct();
2582 taskStructMgr.createGEPsToPrivateVars();
2583
2584 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
2585
// Initialize, before task creation, every privatizer that reads from its
// mold; the result lives in the task context structure.
2586 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2587 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2588 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2589 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2590 // To be handled inside the taskloop.
2591 if (!privDecl.readsFromMold())
2592 continue;
2593 assert(llvmPrivateVarAlloc &&
2594 "reads from mold so shouldn't have been skipped");
2595
2596 llvm::Expected<llvm::Value *> privateVarOrErr =
2597 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2598 blockArg, llvmPrivateVarAlloc, initBlock);
2599 if (!privateVarOrErr)
2600 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2601
2602 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2603
2604 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2605 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2606
// Fixup for privatizers that yield their result by value (e.g. Fortran
// character boxes): store the yielded value into the context-struct slot and
// reload it so the type matches the region block argument.
2607 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2608 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2609 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2610 // Load it so we have the value pointed to by the GEP
2611 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2612 llvmPrivateVarAlloc);
2613 }
2614 assert(llvmPrivateVarAlloc->getType() ==
2615 moduleTranslation.convertType(blockArg.getType()));
2616 }
2617
2618 // firstprivate copy region
2619 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2620 if (failed(copyFirstPrivateVars(
2621 taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2622 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2623 taskloopOp.getPrivateNeedsBarrier())))
2624 return llvm::failure();
2625
2626 // Set up insertion point for call to createTaskloop()
2627 builder.SetInsertPoint(taskloopStartBlock);
2628
// Body callback: generates the task body. Private variables that do not read
// from the mold are allocated/initialized here; mold-reading ones are mapped
// to their slots in the (per-task) context structure.
2629 auto bodyCB = [&](InsertPointTy allocaIP,
2630 InsertPointTy codegenIP) -> llvm::Error {
2631 // Save the alloca insertion point on ModuleTranslation stack for use in
2632 // nested regions.
// NOTE(review): dropped line 2633 here — presumably the SaveStack frame
// declaration that these arguments belong to.
2634 moduleTranslation, allocaIP);
2635
2636 // translate the body of the taskloop:
2637 builder.restoreIP(codegenIP);
2638
2639 llvm::BasicBlock *privInitBlock = nullptr;
2640 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2641 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2642 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2643 privateVarsInfo.mlirVars))) {
2644 auto [blockArg, privDecl, mlirPrivVar] = zip;
2645 // This is handled before the task executes
2646 if (privDecl.readsFromMold())
2647 continue;
2648
// Alloca in the task's entry, then run the privatizer's init region.
2649 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2650 llvm::Type *llvmAllocType =
2651 moduleTranslation.convertType(privDecl.getType());
2652 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2653 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2654 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2655
2656 llvm::Expected<llvm::Value *> privateVarOrError =
2657 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2658 blockArg, llvmPrivateVar, privInitBlock);
2659 if (!privateVarOrError)
2660 return privateVarOrError.takeError();
2661 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2662 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2663 }
2664
// Re-create GEPs relative to the task's own copy of the context structure
// and record them for the mold-reading variables.
2665 taskStructMgr.createGEPsToPrivateVars();
2666 for (auto [i, llvmPrivVar] :
2667 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2668 if (!llvmPrivVar) {
2669 assert(privateVarsInfo.llvmVars[i] &&
2670 "This is added in the loop above");
2671 continue;
2672 }
2673 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2674 }
2675
2676 // Find and map the addresses of each variable within the taskloop context
2677 // structure
2678 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2679 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2680 privateVarsInfo.privatizers)) {
2681 // This was handled above.
2682 if (!privateDecl.readsFromMold())
2683 continue;
2684 // Fix broken pass-by-value case for Fortran character boxes
2685 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2686 llvmPrivateVar = builder.CreateLoad(
2687 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2688 }
2689 assert(llvmPrivateVar->getType() ==
2690 moduleTranslation.convertType(blockArg.getType()));
2691 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2692 }
2693
2694 auto continuationBlockOrError =
2695 convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
2696 builder, moduleTranslation);
2697
2698 if (failed(handleError(continuationBlockOrError, opInst)))
2699 return llvm::make_error<PreviouslyReportedError>();
2700
2701 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2702
2703 // This is freeing the private variables as mapped inside of the task: these
2704 // will be per-task private copies possibly after task duplication. This is
2705 // handled transparently by how these are passed to the structure passed
2706 // into the outlined function. When the task is duplicated, that structure
2707 // is duplicated too.
2708 if (failed(cleanupPrivateVars(builder, moduleTranslation,
2709 taskloopOp.getLoc(), privateVarsInfo.llvmVars,
2710 privateVarsInfo.privatizers)))
2711 return llvm::make_error<PreviouslyReportedError>();
2712 // Similarly, the task context structure freed inside the task is the
2713 // per-task copy after task duplication.
2714 taskStructMgr.freeStructPtr();
2715
2716 return llvm::Error::success();
2717 };
2718
2719 // Taskloop divides into an appropriate number of tasks by repeatedly
2720 // duplicating the original task. Each time this is done, the task context
2721 // structure must be duplicated too.
2722 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2723 llvm::Value *destPtr, llvm::Value *srcPtr)
// NOTE(review): dropped line 2724 here — presumably the lambda's trailing
// return type and opening brace.
2725 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2726 builder.restoreIP(codegenIP);
2727
// Load the source context-structure pointer and allocate a fresh destination
// structure, publishing its address through destPtr.
2728 llvm::Type *ptrTy =
2729 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2730 llvm::Value *src =
2731 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
2732
2733 TaskContextStructManager &srcStructMgr = taskStructMgr;
2734 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2735 privateVarsInfo.privatizers);
2736 destStructMgr.generateTaskContextStruct();
2737 llvm::Value *dest = destStructMgr.getStructPtr();
2738 dest->setName("omp.taskloop.context.dest");
2739 builder.CreateStore(dest, destPtr);
2740
// NOTE(review): dropped lines 2741/2743 here — presumably the definitions of
// `srcGEPs` and `destGEPs` used by the loops below.
2742 srcStructMgr.createGEPsToPrivateVars(src);
2744 destStructMgr.createGEPsToPrivateVars(dest);
2745
2746 // Inline init regions.
2747 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2748 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
2749 privateVarsInfo.blockArgs, destGEPs)) {
2750 // To be handled inside task body.
2751 if (!privDecl.readsFromMold())
2752 continue;
2753 assert(llvmPrivateVarAlloc &&
2754 "reads from mold so shouldn't have been skipped");
2755
2756 llvm::Expected<llvm::Value *> privateVarOrErr =
2757 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2758 llvmPrivateVarAlloc, builder.GetInsertBlock());
2759 if (!privateVarOrErr)
2760 return privateVarOrErr.takeError();
2761
// NOTE(review): dropped line 2762 here — content unknown; verify upstream.
2763
2764 // TODO: this is a bit of a hack for Fortran character boxes.
2765 // Character boxes are passed by value into the init region and then the
2766 // initialized character box is yielded by value. Here we need to store
2767 // the yielded value into the private allocation, and load the private
2768 // allocation to match the type expected by region block arguments.
2769 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2770 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2771 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2772 // Load it so we have the value pointed to by the GEP
2773 llvmPrivateVarAlloc = builder.CreateLoad(
2774 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2775 }
2776 assert(llvmPrivateVarAlloc->getType() ==
2777 moduleTranslation.convertType(blockArg.getType()));
2778
2779 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
2780 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
2781 // through a stack allocated structure.
2782 }
2783
2784 if (failed(copyFirstPrivateVars(
2785 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2786 privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
2787 return llvm::make_error<PreviouslyReportedError>();
2788
2789 return builder.saveIP();
2790 };
2791
2792 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
2793
// Lazily fetch the CanonicalLoopInfo built for the nested omp.loop_nest.
2794 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
2795 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2796 return loopInfo;
2797 };
2798
2799 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
2800 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
2801 Operation::operand_range steps = loopOp.getLoopSteps();
2802 llvm::Type *boundType =
2803 moduleTranslation.lookupValue(lowerBounds[0])->getType();
2804 llvm::Value *lbVal = nullptr;
// ubVal starts at 1 so collapsed trip counts can be multiplied into it.
2805 llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2806 llvm::Value *stepVal = nullptr;
2807 if (loopOp.getCollapseNumLoops() > 1) {
2808 // In cases where Collapse is used with Taskloop, the upper bound of the
2809 // iteration space needs to be recalculated to cater for the collapsed loop.
2810 // The Collapsed Loop UpperBound is the product of all collapsed
2811 // loop's tripcount.
2812 // The LowerBound for collapsed loops is always 1. When the loops are
2813 // collapsed, it will reset the bounds and introduce processing to ensure
2814 // the index's are presented as expected. As this happens after creating
2815 // Taskloop, these bounds need predicting. Example:
2816 // !$omp taskloop collapse(2)
2817 // do i = 1, 10
2818 // do j = 1, 5
2819 // ..
2820 // end do
2821 // end do
2822 // This loop above has a total of 50 iterations, so the lb will be 1, and
2823 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
2824 // that i and j are properly presented when used in the loop.
2825 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
2826 llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
2827 llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
2828 llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);
2829 // In some cases, such as where the ub is less than the lb so the loop
2830 // steps down, the calculation for the loopTripCount is swapped. To ensure
2831 // the correct value is found, calculate both UB - LB and LB - UB then
2832 // select which value to use depending on how the loop has been
2833 // configured.
2834 llvm::Value *loopLbMinusOne = builder.CreateSub(
2835 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2836 llvm::Value *loopUbMinusOne = builder.CreateSub(
2837 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2838 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
2839 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
2840 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
2841 llvm::Value *loopTripCount =
2842 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
2843 loopTripCount = builder.CreateBinaryIntrinsic(
2844 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
2845 // For loops that have a step value not equal to 1, we need to adjust the
2846 // trip count to ensure the correct number of iterations for the loop is
2847 // captured.
2848 llvm::Value *loopTripCountDivStep =
2849 builder.CreateSDiv(loopTripCount, loopStep);
2850 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
2851 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
2852 llvm::Value *loopTripCountRem =
2853 builder.CreateSRem(loopTripCount, loopStep);
2854 loopTripCountRem = builder.CreateBinaryIntrinsic(
2855 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
// Round the division up when the step does not evenly divide the range.
2856 llvm::Value *needsRoundUp = builder.CreateICmpNE(
2857 loopTripCountRem,
2858 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
2859 0));
2860 loopTripCount =
2861 builder.CreateAdd(loopTripCountDivStep,
2862 builder.CreateZExtOrTrunc(
2863 needsRoundUp, loopTripCountDivStep->getType()));
2864 ubVal = builder.CreateMul(ubVal, loopTripCount);
2865 }
2866 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2867 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2868 } else {
2869 lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
2870 ubVal = moduleTranslation.lookupValue(upperBounds[0]);
2871 stepVal = moduleTranslation.lookupValue(steps[0]);
2872 }
2873 assert(lbVal != nullptr && "Expected value for lbVal");
2874 assert(ubVal != nullptr && "Expected value for ubVal");
2875 assert(stepVal != nullptr && "Expected value for stepVal");
2876
// Translate the grainsize/num_tasks clauses; `sched` encodes which one is
// active (0 = default, 1 = grainsize, 2 = num_tasks).
2877 llvm::Value *ifCond = nullptr;
2878 llvm::Value *grainsize = nullptr;
2879 int sched = 0; // default
2880 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2881 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2882 if (Value ifVar = taskloopOp.getIfExpr())
2883 ifCond = moduleTranslation.lookupValue(ifVar);
2884 if (grainsizeVal) {
2885 grainsize = moduleTranslation.lookupValue(grainsizeVal);
2886 sched = 1; // grainsize
2887 } else if (numTasksVal) {
2888 grainsize = moduleTranslation.lookupValue(numTasksVal);
2889 sched = 2; // num_tasks
2890 }
2891
// Only install the duplication callback when there is a context structure to
// duplicate.
2892 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
2893 if (taskStructMgr.getStructPtr())
2894 taskDupOrNull = taskDupCB;
2895
2896 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2897 SmallVector<llvm::BranchInst *> cancelTerminators;
2898 // The directive to match here is OMPD_taskgroup because it is the
2899 // taskgroup which is canceled. This is handled here because it is the
2900 // task's cleanup block which should be branched to. It doesn't depend upon
2901 // nogroup because even in that case the taskloop might still be inside an
2902 // explicit taskgroup.
2903 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
2904 llvm::omp::Directive::OMPD_taskgroup);
2905
2906 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2907 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2908 moduleTranslation.getOpenMPBuilder()->createTaskloop(
2909 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
2910 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2911 sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
2912 taskloopOp.getMergeable(),
2913 moduleTranslation.lookupValue(taskloopOp.getPriority()),
2914 loopOp.getCollapseNumLoops(), taskDupOrNull,
2915 taskStructMgr.getStructPtr());
2916
2917 if (failed(handleError(afterIP, opInst)))
2918 return failure();
2919
// Set the correct branch target for taskgroup cancellation.
2920 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2921
2922 builder.restoreIP(*afterIP);
2923 return success();
2924}
2925
2926/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2927static LogicalResult
2928convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2929 LLVM::ModuleTranslation &moduleTranslation) {
2930 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2931 if (failed(checkImplementationStatus(*tgOp)))
2932 return failure();
2933
2934 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2935 builder.restoreIP(codegenIP);
2936 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2937 builder, moduleTranslation)
2938 .takeError();
2939 };
2940
2941 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2942 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2943 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2944 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2945 bodyCB);
2946
2947 if (failed(handleError(afterIP, *tgOp)))
2948 return failure();
2949
2950 builder.restoreIP(*afterIP);
2951 return success();
2952}
2953
2954static LogicalResult
2955convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2956 LLVM::ModuleTranslation &moduleTranslation) {
2957 if (failed(checkImplementationStatus(*twOp)))
2958 return failure();
2959
2960 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2961 return success();
2962}
2963
2964 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
// Pipeline: translate clauses (schedule, dist_schedule for the composite
// 'distribute parallel do' case), allocate/initialize private and reduction
// variables, handle the linear clause, emit the loop region, apply the
// workshare-loop transformation to the resulting canonical loop, and finally
// emit reductions and private-variable cleanup.
//
// NOTE(review): this listing appears to have dropped declaration lines
// (listing lines 3006, 3014 and 3087) — presumably the declarations of
// `reductionDecls`, `afterAllocas` and `regionBlock` used below. Verify
// against the upstream source before editing this function.
2965 static LogicalResult
2966 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2967 LLVM::ModuleTranslation &moduleTranslation) {
2968 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2969 auto wsloopOp = cast<omp::WsloopOp>(opInst);
// Bail out early if the op carries clauses that are not implemented yet.
2970 if (failed(checkImplementationStatus(opInst)))
2971 return failure();
2972
2973 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2974 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
2975 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2976
2977 // Static is the default.
2978 auto schedule =
2979 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2980
2981 // Find the loop configuration.
2982 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2983 llvm::Type *ivType = step->getType();
2984 llvm::Value *chunk = nullptr;
2985 if (wsloopOp.getScheduleChunk()) {
2986 llvm::Value *chunkVar =
2987 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
// Coerce the chunk size to the induction variable's type.
2988 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2989 }
2990
// When directly nested in omp.distribute ('distribute parallel do'), pick up
// the dist_schedule clause from the parent.
2991 omp::DistributeOp distributeOp = nullptr;
2992 llvm::Value *distScheduleChunk = nullptr;
2993 bool hasDistSchedule = false;
2994 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
2995 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
2996 hasDistSchedule = distributeOp.getDistScheduleStatic();
2997 if (distributeOp.getDistScheduleChunkSize()) {
2998 llvm::Value *chunkVar = moduleTranslation.lookupValue(
2999 distributeOp.getDistScheduleChunkSize());
3000 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3001 }
3002 }
3003
3004 PrivateVarsInfo privateVarsInfo(wsloopOp);
3005
// NOTE(review): dropped line 3006 here — presumably the declaration of
// `reductionDecls` filled in by collectReductionDecls.
3007 collectReductionDecls(wsloopOp, reductionDecls);
3008 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3009 findAllocaInsertPoint(builder, moduleTranslation);
3010
3011 SmallVector<llvm::Value *> privateReductionVariables(
3012 wsloopOp.getNumReductionVars());
3013
// NOTE(review): dropped line 3014 here — presumably the declaration of
// `afterAllocas` (the block following private-variable allocation).
3015 builder, moduleTranslation, privateVarsInfo, allocaIP);
3016 if (handleError(afterAllocas, opInst).failed())
3017 return failure();
3018
3019 DenseMap<Value, llvm::Value *> reductionVariableMap;
3020
3021 MutableArrayRef<BlockArgument> reductionArgs =
3022 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3023
3024 SmallVector<DeferredStore> deferredStores;
3025
// Allocate storage for the reduction accumulators; by-ref initial stores are
// deferred until after initialization.
3026 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3027 moduleTranslation, allocaIP, reductionDecls,
3028 privateReductionVariables, reductionVariableMap,
3029 deferredStores, isByRef)))
3030 return failure();
3031
// Run the privatizers' init regions, then perform firstprivate copies.
3032 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3033 opInst)
3034 .failed())
3035 return failure();
3036
3037 if (failed(copyFirstPrivateVars(
3038 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3039 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3040 wsloopOp.getPrivateNeedsBarrier())))
3041 return failure();
3042
3043 assert(afterAllocas.get()->getSinglePredecessor());
3044 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3045 moduleTranslation,
3046 afterAllocas.get()->getSinglePredecessor(),
3047 reductionDecls, privateReductionVariables,
3048 reductionVariableMap, isByRef, deferredStores)))
3049 return failure();
3050
3051 // TODO: Handle doacross loops when the ordered clause has a parameter.
3052 bool isOrdered = wsloopOp.getOrdered().has_value();
3053 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3054 bool isSimd = wsloopOp.getScheduleSimd();
3055 bool loopNeedsBarrier = !wsloopOp.getNowait();
3056
3057 // The only legal way for the direct parent to be omp.distribute is that this
3058 // represents 'distribute parallel do'. Otherwise, this is a regular
3059 // worksharing loop.
3060 llvm::omp::WorksharingLoopType workshareLoopType =
3061 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3062 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3063 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3064
// Register the cleanup block used when this worksharing region is canceled.
3065 SmallVector<llvm::BranchInst *> cancelTerminators;
3066 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3067 llvm::omp::Directive::OMPD_for);
3068
3069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3070
3071 // Initialize linear variables and linear step
3072 LinearClauseProcessor linearClauseProcessor;
3073
3074 if (!wsloopOp.getLinearVars().empty()) {
3075 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3076 for (mlir::Attribute linearVarType : linearVarTypes)
3077 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3078
3079 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3080 linearClauseProcessor.createLinearVar(
3081 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3082 idx);
3083 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3084 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3085 }
3086
// NOTE(review): dropped line 3087 here — presumably the declaration of
// `regionBlock` returned by convertOmpOpRegions.
3088 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3089
3090 if (failed(handleError(regionBlock, opInst)))
3091 return failure();
3092
3093 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3094
3095 // Emit Initialization and Update IR for linear variables
3096 if (!wsloopOp.getLinearVars().empty()) {
3097 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3098 loopInfo->getPreheader());
3099 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3100 moduleTranslation.getOpenMPBuilder()->createBarrier(
3101 builder.saveIP(), llvm::omp::OMPD_barrier);
3102 if (failed(handleError(afterBarrierIP, *loopOp)))
3103 return failure();
3104 builder.restoreIP(*afterBarrierIP);
3105 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3106 loopInfo->getIndVar());
3107 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3108 }
3109
3110 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3111
3112 // Check if we can generate no-loop kernel
3113 bool noLoopMode = false;
3114 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3115 if (targetOp) {
3116 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3117 // We need this check because, without it, noLoopMode would be set to true
3118 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3119 // that loop is not the top-level SPMD one.
3120 if (loopOp == targetCapturedOp) {
3121 omp::TargetRegionFlags kernelFlags =
3122 targetOp.getKernelExecFlags(targetCapturedOp);
3123 if (omp::bitEnumContainsAll(kernelFlags,
3124 omp::TargetRegionFlags::spmd |
3125 omp::TargetRegionFlags::no_loop) &&
3126 !omp::bitEnumContainsAny(kernelFlags,
3127 omp::TargetRegionFlags::generic))
3128 noLoopMode = true;
3129 }
3130 }
3131
// Apply the workshare-loop transformation to the canonical loop built for
// the nested omp.loop_nest.
3132 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3133 ompBuilder->applyWorkshareLoop(
3134 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3135 convertToScheduleKind(schedule), chunk, isSimd,
3136 scheduleMod == omp::ScheduleModifier::monotonic,
3137 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3138 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3139
3140 if (failed(handleError(wsloopIP, opInst)))
3141 return failure();
3142
3143 // Emit finalization and in-place rewrites for linear vars.
3144 if (!wsloopOp.getLinearVars().empty()) {
3145 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3146 assert(loopInfo->getLastIter() &&
3147 "`lastiter` in CanonicalLoopInfo is nullptr");
3148 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3149 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3150 loopInfo->getLastIter());
3151 if (failed(handleError(afterBarrierIP, *loopOp)))
3152 return failure();
3153 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3154 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3155 index);
3156 builder.restoreIP(oldIP);
3157 }
3158
3159 // Set the correct branch target for task cancellation
3160 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3161
3162 // Process the reductions if required.
3163 if (failed(createReductionsAndCleanup(
3164 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3165 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3166 /*isTeamsReduction=*/false)))
3167 return failure();
3168
// Run the privatizers' dealloc regions and return the overall status.
3169 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
3170 privateVarsInfo.llvmVars,
3171 privateVarsInfo.privatizers);
3172}
3173
3174/// Converts the OpenMP parallel operation to LLVM IR.
/// All clause handling (private, firstprivate, reductions, cancellation
/// barrier) is performed inside the callbacks handed to
/// OpenMPIRBuilder::createParallel; privCB is deliberately a no-op.
3175static LogicalResult
3176convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3177 LLVM::ModuleTranslation &moduleTranslation) {
3178 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3179 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3180 assert(isByRef.size() == opInst.getNumReductionVars());
3181 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3182 bool isCancellable = constructIsCancellable(opInst);
3183
// Bail out early on clauses this translation does not support yet.
3184 if (failed(checkImplementationStatus(*opInst)))
3185 return failure();
3186
3187 PrivateVarsInfo privateVarsInfo(opInst);
3188
3189 // Collect reduction declarations
3191 collectReductionDecls(opInst, reductionDecls);
3192 SmallVector<llvm::Value *> privateReductionVariables(
3193 opInst.getNumReductionVars());
3194 SmallVector<DeferredStore> deferredStores;
3195
// bodyGenCB: allocate and initialize private/firstprivate/reduction storage,
// inline the parallel region, then emit the reduction combine code.
3196 auto bodyGenCB = [&](InsertPointTy allocaIP,
3197 InsertPointTy codeGenIP) -> llvm::Error {
3199 builder, moduleTranslation, privateVarsInfo, allocaIP);
3200 if (handleError(afterAllocas, *opInst).failed())
3201 return llvm::make_error<PreviouslyReportedError>();
3202
3203 // Allocate reduction vars
3204 DenseMap<Value, llvm::Value *> reductionVariableMap;
3205
3206 MutableArrayRef<BlockArgument> reductionArgs =
3207 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3208
// Re-anchor the alloca insertion point at the alloca block's terminator so
// the reduction allocas land after the private-var allocas created above.
3209 allocaIP =
3210 InsertPointTy(allocaIP.getBlock(),
3211 allocaIP.getBlock()->getTerminator()->getIterator());
3212
3213 if (failed(allocReductionVars(
3214 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3215 reductionDecls, privateReductionVariables, reductionVariableMap,
3216 deferredStores, isByRef)))
3217 return llvm::make_error<PreviouslyReportedError>();
3218
3219 assert(afterAllocas.get()->getSinglePredecessor());
3220 builder.restoreIP(codeGenIP);
3221
3222 if (handleError(
3223 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3224 *opInst)
3225 .failed())
3226 return llvm::make_error<PreviouslyReportedError>();
3227
3228 if (failed(copyFirstPrivateVars(
3229 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3230 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3231 opInst.getPrivateNeedsBarrier())))
3232 return llvm::make_error<PreviouslyReportedError>();
3233
3234 if (failed(
3235 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3236 afterAllocas.get()->getSinglePredecessor(),
3237 reductionDecls, privateReductionVariables,
3238 reductionVariableMap, isByRef, deferredStores)))
3239 return llvm::make_error<PreviouslyReportedError>();
3240
3241 // Save the alloca insertion point on ModuleTranslation stack for use in
3242 // nested regions.
3244 moduleTranslation, allocaIP);
3245
3246 // ParallelOp has only one region associated with it.
3248 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3249 if (!regionBlock)
3250 return regionBlock.takeError();
3251
3252 // Process the reductions if required.
3253 if (opInst.getNumReductionVars() > 0) {
3254 // Collect reduction info
3256 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3258 owningReductionGenRefDataPtrGens;
3260 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3261 owningReductionGens, owningAtomicReductionGens,
3262 owningReductionGenRefDataPtrGens,
3263 privateReductionVariables, reductionInfos, isByRef);
3264
3265 // Move to region cont block
3266 builder.SetInsertPoint((*regionBlock)->getTerminator());
3267
// createReductions needs a terminated block to split; emit a placeholder
// terminator, build the reductions before it, then erase it again.
3268 // Generate reductions from info
3269 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3270 builder.SetInsertPoint(tempTerminator);
3271
3272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3273 ompBuilder->createReductions(
3274 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3275 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3276 if (!contInsertPoint)
3277 return contInsertPoint.takeError();
3278
3279 if (!contInsertPoint->getBlock())
3280 return llvm::make_error<PreviouslyReportedError>();
3281
3282 tempTerminator->eraseFromParent();
3283 builder.restoreIP(*contInsertPoint);
3284 }
3285
3286 return llvm::Error::success();
3287 };
3288
3289 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3290 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3291 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3292 // bodyGenCB.
3293 replVal = &val;
3294 return codeGenIP;
3295 };
3296
3297 // TODO: Perform finalization actions for variables. This has to be
3298 // called for variables which have destructors/finalizers.
// finiCB runs on exit from the outlined region: inline reduction cleanup
// regions, destroy private vars, and add the cancellation barrier if needed.
3299 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3300 InsertPointTy oldIP = builder.saveIP();
3301 builder.restoreIP(codeGenIP);
3302
3303 // if the reduction has a cleanup region, inline it here to finalize the
3304 // reduction variables
3305 SmallVector<Region *> reductionCleanupRegions;
3306 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3307 [](omp::DeclareReductionOp reductionDecl) {
3308 return &reductionDecl.getCleanupRegion();
3309 });
3310 if (failed(inlineOmpRegionCleanup(
3311 reductionCleanupRegions, privateReductionVariables,
3312 moduleTranslation, builder, "omp.reduction.cleanup")))
3313 return llvm::createStringError(
3314 "failed to inline `cleanup` region of `omp.declare_reduction`");
3315
3316 if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
3317 privateVarsInfo.llvmVars,
3318 privateVarsInfo.privatizers)))
3319 return llvm::make_error<PreviouslyReportedError>();
3320
3321 // If we could be performing cancellation, add the cancellation barrier on
3322 // the way out of the outlined region.
3323 if (isCancellable) {
3324 auto IPOrErr = ompBuilder->createBarrier(
3325 llvm::OpenMPIRBuilder::LocationDescription(builder),
3326 llvm::omp::Directive::OMPD_unknown,
3327 /* ForceSimpleCall */ false,
3328 /* CheckCancelFlag */ false);
3329 if (!IPOrErr)
3330 return IPOrErr.takeError();
3331 }
3332
3333 builder.restoreIP(oldIP);
3334 return llvm::Error::success();
3335 };
3336
// Translate the if/num_threads/proc_bind clauses into the runtime arguments
// expected by createParallel.
3337 llvm::Value *ifCond = nullptr;
3338 if (auto ifVar = opInst.getIfExpr())
3339 ifCond = moduleTranslation.lookupValue(ifVar);
3340 llvm::Value *numThreads = nullptr;
// NOTE(review): only the first num_threads operand is forwarded here —
// presumably the op carries at most one value at this stage; confirm.
3341 if (!opInst.getNumThreadsVars().empty())
3342 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0))
3343 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3344 if (auto bind = opInst.getProcBindKind())
3345 pbKind = getProcBindKind(*bind);
3346
3347 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3348 findAllocaInsertPoint(builder, moduleTranslation);
3349 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3350
3351 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3352 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3353 ifCond, numThreads, pbKind, isCancellable);
3354
3355 if (failed(handleError(afterIP, *opInst)))
3356 return failure();
3357
// Resume emission after the parallel construct.
3358 builder.restoreIP(*afterIP);
3359 return success();
3360}
3361
3362/// Convert Order attribute to llvm::omp::OrderKind.
3363static llvm::omp::OrderKind
3364convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3365 if (!o)
3366 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3367 switch (*o) {
3368 case omp::ClauseOrderKind::Concurrent:
3369 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3370 }
3371 llvm_unreachable("Unknown ClauseOrderKind kind");
3372}
3373
3374/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
/// Handles aligned, linear, if, simdlen/safelen, order and reduction clauses.
/// Reductions are combined inline (no runtime calls) because the whole simd
/// construct executes within a single thread.
3375static LogicalResult
3376convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
3377 LLVM::ModuleTranslation &moduleTranslation) {
3378 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3379 auto simdOp = cast<omp::SimdOp>(opInst);
3380
3381 if (failed(checkImplementationStatus(opInst)))
3382 return failure();
3383
3384 PrivateVarsInfo privateVarsInfo(simdOp);
3385
3386 MutableArrayRef<BlockArgument> reductionArgs =
3387 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3388 DenseMap<Value, llvm::Value *> reductionVariableMap;
3389 SmallVector<llvm::Value *> privateReductionVariables(
3390 simdOp.getNumReductionVars());
3391 SmallVector<DeferredStore> deferredStores;
3393 collectReductionDecls(simdOp, reductionDecls);
3394 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
3395 assert(isByRef.size() == simdOp.getNumReductionVars());
3396
3397 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3398 findAllocaInsertPoint(builder, moduleTranslation);
3399
3401 builder, moduleTranslation, privateVarsInfo, allocaIP);
3402 if (handleError(afterAllocas, opInst).failed())
3403 return failure();
3404
3405 // Initialize linear variables and linear step
3406 LinearClauseProcessor linearClauseProcessor;
3407
3408 if (!simdOp.getLinearVars().empty()) {
3409 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3410 for (mlir::Attribute linearVarType : linearVarTypes)
3411 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3412 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3413 bool isImplicit = false;
3414 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3415 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
3416 // If the linear variable is implicit, reuse the already
3417 // existing llvm::Value
3418 if (linearVar == mlirPrivVar) {
3419 isImplicit = true;
3420 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3421 llvmPrivateVar, idx);
3422 break;
3423 }
3424 }
3425
// Explicit linear variable: translate its MLIR value directly.
3426 if (!isImplicit)
3427 linearClauseProcessor.createLinearVar(
3428 builder, moduleTranslation,
3429 moduleTranslation.lookupValue(linearVar), idx);
3430 }
3431 for (mlir::Value linearStep : simdOp.getLinearStepVars())
3432 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3433 }
3434
3435 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
3436 moduleTranslation, allocaIP, reductionDecls,
3437 privateReductionVariables, reductionVariableMap,
3438 deferredStores, isByRef)))
3439 return failure();
3440
3441 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3442 opInst)
3443 .failed())
3444 return failure();
3445
3446 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
3447 // SIMD.
3448
3449 assert(afterAllocas.get()->getSinglePredecessor());
3450 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3451 moduleTranslation,
3452 afterAllocas.get()->getSinglePredecessor(),
3453 reductionDecls, privateReductionVariables,
3454 reductionVariableMap, isByRef, deferredStores)))
3455 return failure();
3456
// simdlen/safelen clause values become i64 constants for applySimd.
3457 llvm::ConstantInt *simdlen = nullptr;
3458 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3459 simdlen = builder.getInt64(simdlenVar.value());
3460
3461 llvm::ConstantInt *safelen = nullptr;
3462 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3463 safelen = builder.getInt64(safelenVar.value());
3464
3465 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3466 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
3467
// Pointer loads for aligned variables are emitted back in sourceBlock
// (before the simd region) and collected with their alignment constants.
3468 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3469 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3470 mlir::OperandRange operands = simdOp.getAlignedVars();
3471 for (size_t i = 0; i < operands.size(); ++i) {
3472 llvm::Value *alignment = nullptr;
3473 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
3474 llvm::Type *ty = llvmVal->getType();
3475
3476 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3477 alignment = builder.getInt64(intAttr.getInt());
3478 assert(ty->isPointerTy() && "Invalid type for aligned variable");
3479 assert(alignment && "Invalid alignment value");
3480
3481 // Check if the alignment value is not a power of 2. If so, skip emitting
3482 // alignment.
3483 if (!intAttr.getValue().isPowerOf2())
3484 continue;
3485
3486 auto curInsert = builder.saveIP();
3487 builder.SetInsertPoint(sourceBlock);
3488 llvmVal = builder.CreateLoad(ty, llvmVal);
3489 builder.restoreIP(curInsert);
3490 alignedVars[llvmVal] = alignment;
3491 }
3492
3494 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
3495
3496 if (failed(handleError(regionBlock, opInst)))
3497 return failure();
3498
3499 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3500 // Emit Initialization for linear variables
3501 if (simdOp.getLinearVars().size()) {
3502 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3503 loopInfo->getPreheader());
3504
3505 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3506 loopInfo->getIndVar());
3507 }
3508 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3509
// Attach the simd metadata/attributes to the just-built canonical loop.
3510 ompBuilder->applySimd(loopInfo, alignedVars,
3511 simdOp.getIfExpr()
3512 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
3513 : nullptr,
3514 order, simdlen, safelen);
3515
3516 linearClauseProcessor.emitStoresForLinearVar(builder);
3517
3518 // Check if this SIMD loop contains ordered regions
3519 bool hasOrderedRegions = false;
3520 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
3521 hasOrderedRegions = true;
3522 return WalkResult::interrupt();
3523 });
3524
3525 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
3526 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3527 index);
3528 if (hasOrderedRegions) {
3529 // Also rewrite uses in ordered regions so they read the current value
3530 linearClauseProcessor.rewriteInPlace(builder, "omp.ordered.region",
3531 index);
3532 // Also rewrite uses in finalize blocks (code after ordered regions)
3533 linearClauseProcessor.rewriteInPlace(builder, "omp_region.finalize",
3534 index);
3535 }
3536 }
3537
3538 // We now need to reduce the per-simd-lane reduction variable into the
3539 // original variable. This works a bit differently to other reductions (e.g.
3540 // wsloop) because we don't need to call into the OpenMP runtime to handle
3541 // threads: everything happened in this one thread.
3542 for (auto [i, tuple] : llvm::enumerate(
3543 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3544 privateReductionVariables))) {
3545 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3546
3547 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
3548 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
3549 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
3550
3551 // We have one less load for by-ref case because that load is now inside of
3552 // the reduction region.
3553 llvm::Value *redValue = originalVariable;
3554 if (!byRef)
3555 redValue =
3556 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
3557 llvm::Value *privateRedValue = builder.CreateLoad(
3558 reductionType, privateReductionVar, "red.private.value." + Twine(i));
3559 llvm::Value *reduced;
3560
3561 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3562 if (failed(handleError(res, opInst)))
3563 return failure();
3564 builder.restoreIP(res.get());
3565
3566 // For by-ref case, the store is inside of the reduction region.
3567 if (!byRef)
3568 builder.CreateStore(reduced, originalVariable);
3569 }
3570
3571 // After the construct, deallocate private reduction variables.
3572 SmallVector<Region *> reductionRegions;
3573 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3574 [](omp::DeclareReductionOp reductionDecl) {
3575 return &reductionDecl.getCleanupRegion();
3576 });
3577 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
3578 moduleTranslation, builder,
3579 "omp.reduction.cleanup")))
3580 return failure();
3581
3582 return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
3583 privateVarsInfo.llvmVars,
3584 privateVarsInfo.privatizers);
3585}
3586
3587/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
/// Builds one canonical loop per loop in the nest (innermost body converted
/// once all levels are entered), optionally tiles them, then collapses the
/// leading loops into a single canonical loop recorded on the loop-info
/// stack frame for the enclosing worksharing construct.
3588static LogicalResult
3589convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
3590 LLVM::ModuleTranslation &moduleTranslation) {
3591 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3592 auto loopOp = cast<omp::LoopNestOp>(opInst);
3593
3594 if (failed(checkImplementationStatus(opInst)))
3595 return failure();
3596
3597 // Set up the source location value for OpenMP runtime.
3598 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3599
3600 // Generator of the canonical loop body.
3603 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3604 llvm::Value *iv) -> llvm::Error {
3605 // Make sure further conversions know about the induction variable.
3606 moduleTranslation.mapValue(
3607 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3608
3609 // Capture the body insertion point for use in nested loops. BodyIP of the
3610 // CanonicalLoopInfo always points to the beginning of the entry block of
3611 // the body.
3612 bodyInsertPoints.push_back(ip);
3613
// Only the innermost invocation converts the region; outer levels just
// record their insertion point and return.
3614 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3615 return llvm::Error::success();
3616
3617 // Convert the body of the loop.
3618 builder.restoreIP(ip);
3620 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
3621 if (!regionBlock)
3622 return regionBlock.takeError();
3623
3624 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3625 return llvm::Error::success();
3626 };
3627
3628 // Delegate actual loop construction to the OpenMP IRBuilder.
3629 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
3630 // loop, i.e. it has a positive step, uses signed integer semantics.
3631 // Reconsider this code when the nested loop operation clearly supports more
3632 // cases.
3633 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3634 llvm::Value *lowerBound =
3635 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3636 llvm::Value *upperBound =
3637 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3638 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
3639
3640 // Make sure loop trip count are emitted in the preheader of the outermost
3641 // loop at the latest so that they are all available for the new collapsed
3642 // loop will be created below.
3643 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3644 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3645 if (i != 0) {
// Inner loops are created at the body of the enclosing loop, but their
// trip counts are still computed in the outermost preheader.
3646 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3647 ompLoc.DL);
3648 computeIP = loopInfos.front()->getPreheaderIP();
3649 }
3650
3652 ompBuilder->createCanonicalLoop(
3653 loc, bodyGen, lowerBound, upperBound, step,
3654 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
3655
3656 if (failed(handleError(loopResult, *loopOp)))
3657 return failure();
3658
3659 loopInfos.push_back(*loopResult);
3660 }
3661
// Remember where emission continues after the (outermost) loop; tiling may
// move this point below.
3662 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3663 loopInfos.front()->getAfterIP();
3664
3665 // Do tiling.
3666 if (const auto &tiles = loopOp.getTileSizes()) {
3667 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3669
3670 for (auto tile : tiles.value()) {
3671 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
3672 tileSizes.push_back(tileVal);
3673 }
3674
3675 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3676 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3677
3678 // Update afterIP to get the correct insertion point after
3679 // tiling.
3680 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3681 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3682 afterIP = {afterAfterBB, afterAfterBB->begin()};
3683
3684 // Update the loop infos.
3685 loopInfos.clear();
3686 for (const auto &newLoop : newLoops)
3687 loopInfos.push_back(newLoop);
3688 } // Tiling done.
3689
3690 // Do collapse.
// Collapse the first `collapse_num_loops` canonical loops into one.
3691 const auto &numCollapse = loopOp.getCollapseNumLoops();
3693 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3694
3695 auto newTopLoopInfo =
3696 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3697
3698 assert(newTopLoopInfo && "New top loop information is missing");
// Publish the collapsed loop on the innermost loop-info stack frame so the
// enclosing worksharing construct applies to it.
3699 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3700 [&](OpenMPLoopInfoStackFrame &frame) {
3701 frame.loopInfo = newTopLoopInfo;
3702 return WalkResult::interrupt();
3703 });
3704
3705 // Continue building IR after the loop. Note that the LoopInfo returned by
3706 // `collapseLoops` points inside the outermost loop and is intended for
3707 // potential further loop transformations. Use the insertion point stored
3708 // before collapsing loops instead.
3709 builder.restoreIP(afterIP);
3710 return success();
3711}
3712
3713/// Convert an omp.canonical_loop to LLVM-IR
/// The trip-count operand must already have been translated. The loop body
/// region is inlined into the canonical loop created by the OpenMPIRBuilder,
/// and the resulting CanonicalLoopInfo is recorded for the op's `cli` result
/// so loop-transformation ops (unroll/tile/fuse) can consume it later.
3714static LogicalResult
3715convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
3716 LLVM::ModuleTranslation &moduleTranslation) {
3717 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3718
3719 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3720 Value loopIV = op.getInductionVar();
3721 Value loopTC = op.getTripCount();
3722
3723 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3724
3726 ompBuilder->createCanonicalLoop(
3727 loopLoc,
3728 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3729 // Register the mapping of MLIR induction variable to LLVM-IR
3730 // induction variable
3731 moduleTranslation.mapValue(loopIV, llvmIV);
3732
3733 builder.restoreIP(ip);
3735 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
3736 moduleTranslation);
3737
// Propagate any error from the body conversion out of the callback.
3738 return bodyGenStatus.takeError();
3739 },
3740 llvmTC, "omp.loop");
3741 if (!llvmOrError)
3742 return op.emitError(llvm::toString(llvmOrError.takeError()));
3743
// Continue emitting IR after the finished loop.
3744 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3745 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3746 builder.restoreIP(afterIP);
3747
3748 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
3749 if (Value cli = op.getCli())
3750 moduleTranslation.mapOmpLoop(cli, llvmCLI);
3751
3752 return success();
3753}
3754
3755/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3756/// OpenMPIRBuilder.
3757static LogicalResult
3758applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3759 LLVM::ModuleTranslation &moduleTranslation) {
3760 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3761
3762 Value applyee = op.getApplyee();
3763 assert(applyee && "Loop to apply unrolling on required");
3764
3765 llvm::CanonicalLoopInfo *consBuilderCLI =
3766 moduleTranslation.lookupOMPLoop(applyee);
3767 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3768 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3769
3770 moduleTranslation.invalidateOmpLoop(applyee);
3771 return success();
3772}
3773
3774/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3775/// OpenMPIRBuilder.
/// The `sizes` operands and all applyee loops must already be translated;
/// the generated loops (if the op names generatees) are registered so that
/// further transformations can consume them.
3776static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3777 LLVM::ModuleTranslation &moduleTranslation) {
3778 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3779 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3780
3782 SmallVector<llvm::Value *> translatedSizes;
3783
// Translate the tile-size operands.
3784 for (Value size : op.getSizes()) {
3785 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3786 assert(translatedSize &&
3787 "sizes clause arguments must already be translated");
3788 translatedSizes.push_back(translatedSize);
3789 }
3790
// Gather the canonical-loop handles the tiling applies to.
// NOTE(review): the assert below runs after `applyee` was already passed to
// lookupOMPLoop; consider asserting before the lookup.
3791 for (Value applyee : op.getApplyees()) {
3792 llvm::CanonicalLoopInfo *consBuilderCLI =
3793 moduleTranslation.lookupOMPLoop(applyee);
3794 assert(applyee && "Canonical loop must already been translated");
3795 translatedLoops.push_back(consBuilderCLI);
3796 }
3797
3798 auto generatedLoops =
3799 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3800 if (!op.getGeneratees().empty()) {
3801 for (auto [mlirLoop, genLoop] :
3802 zip_equal(op.getGeneratees(), generatedLoops))
3803 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3804 }
3805
3806 // CLIs can only be consumed once
3807 for (Value applyee : op.getApplyees())
3808 moduleTranslation.invalidateOmpLoop(applyee);
3809
3810 return success();
3811}
3812
3813/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
3814/// OpenMPIRBuilder.
3815static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
3816 LLVM::ModuleTranslation &moduleTranslation) {
3817 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3818 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3819
3820 // Select what CLIs are going to be fused
3821 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
3822 for (size_t i = 0; i < op.getApplyees().size(); i++) {
3823 Value applyee = op.getApplyees()[i];
3824 llvm::CanonicalLoopInfo *consBuilderCLI =
3825 moduleTranslation.lookupOMPLoop(applyee);
3826 assert(applyee && "Canonical loop must already been translated");
3827 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
3828 beforeFuse.push_back(consBuilderCLI);
3829 else if (op.getCount().has_value() &&
3830 i >= op.getFirst().value() + op.getCount().value() - 1)
3831 afterFuse.push_back(consBuilderCLI);
3832 else
3833 toFuse.push_back(consBuilderCLI);
3834 }
3835 assert(
3836 (op.getGeneratees().empty() ||
3837 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
3838 "Wrong number of generatees");
3839
3840 // do the fuse
3841 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
3842 if (!op.getGeneratees().empty()) {
3843 size_t i = 0;
3844 for (; i < beforeFuse.size(); i++)
3845 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
3846 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
3847 for (; i < afterFuse.size(); i++)
3848 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
3849 }
3850
3851 // CLIs can only be consumed once
3852 for (Value applyee : op.getApplyees())
3853 moduleTranslation.invalidateOmpLoop(applyee);
3854
3855 return success();
3856}
3857
3858/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3859static llvm::AtomicOrdering
3860convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3861 if (!ao)
3862 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3863
3864 switch (*ao) {
3865 case omp::ClauseMemoryOrderKind::Seq_cst:
3866 return llvm::AtomicOrdering::SequentiallyConsistent;
3867 case omp::ClauseMemoryOrderKind::Acq_rel:
3868 return llvm::AtomicOrdering::AcquireRelease;
3869 case omp::ClauseMemoryOrderKind::Acquire:
3870 return llvm::AtomicOrdering::Acquire;
3871 case omp::ClauseMemoryOrderKind::Release:
3872 return llvm::AtomicOrdering::Release;
3873 case omp::ClauseMemoryOrderKind::Relaxed:
3874 return llvm::AtomicOrdering::Monotonic;
3875 }
3876 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3877}
3878
3879/// Convert omp.atomic.read operation to LLVM IR.
3880static LogicalResult
3881convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3882 LLVM::ModuleTranslation &moduleTranslation) {
3883 auto readOp = cast<omp::AtomicReadOp>(opInst);
3884 if (failed(checkImplementationStatus(opInst)))
3885 return failure();
3886
3887 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3888 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3889 findAllocaInsertPoint(builder, moduleTranslation);
3890
3891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3892
3893 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3894 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3895 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3896
3897 llvm::Type *elementType =
3898 moduleTranslation.convertType(readOp.getElementType());
3899
3900 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3901 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3902 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3903 return success();
3904}
3905
3906/// Converts an omp.atomic.write operation to LLVM IR.
3907static LogicalResult
3908convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3909 LLVM::ModuleTranslation &moduleTranslation) {
3910 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3911 if (failed(checkImplementationStatus(opInst)))
3912 return failure();
3913
3914 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3915 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3916 findAllocaInsertPoint(builder, moduleTranslation);
3917
3918 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3919 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3920 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3921 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3922 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3923 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3924 /*isVolatile=*/false};
3925 builder.restoreIP(
3926 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3927 return success();
3928}
3929
3930/// Converts an LLVM dialect binary operation to the corresponding enum value
3931/// for `atomicrmw` supported binary operation.
/// Any other operation maps to BAD_BINOP, which signals to the caller that
/// the update cannot be expressed as a single `atomicrmw` and a cmpxchg
/// loop must be generated instead (see the atomic.update conversion).
3932static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3934 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3935 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3936 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3937 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3938 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3939 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3940 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3941 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3942 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3943 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3944}
3945
3946static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3947 bool &isIgnoreDenormalMode,
3948 bool &isFineGrainedMemory,
3949 bool &isRemoteMemory) {
3950 isIgnoreDenormalMode = false;
3951 isFineGrainedMemory = false;
3952 isRemoteMemory = false;
3953 if (atomicUpdateOp &&
3954 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3955 mlir::omp::AtomicControlAttr atomicControlAttr =
3956 atomicUpdateOp.getAtomicControlAttr();
3957 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3958 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3959 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3960 }
3961}
3962
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// If the update region holds exactly one operation plus the terminator and
/// that operation uses the region argument, a recognized binary op may be
/// lowered as a native atomicrmw; otherwise BAD_BINOP is passed and the
/// builder emits a compare-exchange loop that re-evaluates the region through
/// the `updateFn` callback.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // True when the region argument is the first operand of the update,
    // i.e. `x = x op expr` rather than `x = expr op x`.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  // The atomic variable and its element type are needed in both lowerings.
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code. The callback re-emits the region body with the
  // region argument mapped to the atomically loaded value, then returns the
  // value yielded by the region's terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Optional atomic-control hints attached to the op (see
  // extractAtomicControlFlags); all default to false.
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4047
/// Converts an OpenMP atomic capture construct (an atomic read combined with
/// either an atomic update or an atomic write) using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // An atomic.write capture is always treated as a postfix update: the
    // read captures the value, then the write stores the expression.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix when the update is the second op of the capture pair, i.e. the
    // read happens first and captures the pre-update value.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one op in the region: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback emitting the new value: for a write, just the stored expression;
  // for an update, re-emit the update region with the region argument mapped
  // to the atomically loaded value and return the yielded result.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Optional atomic-control hints; atomicUpdateOp may be null here (write
  // form), in which case all flags stay false.
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4148
4149static llvm::omp::Directive convertCancellationConstructType(
4150 omp::ClauseCancellationConstructType directive) {
4151 switch (directive) {
4152 case omp::ClauseCancellationConstructType::Loop:
4153 return llvm::omp::Directive::OMPD_for;
4154 case omp::ClauseCancellationConstructType::Parallel:
4155 return llvm::omp::Directive::OMPD_parallel;
4156 case omp::ClauseCancellationConstructType::Sections:
4157 return llvm::omp::Directive::OMPD_sections;
4158 case omp::ClauseCancellationConstructType::Taskgroup:
4159 return llvm::omp::Directive::OMPD_taskgroup;
4160 }
4161 llvm_unreachable("Unhandled cancellation construct type");
4162}
4163
4164static LogicalResult
4165convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4166 LLVM::ModuleTranslation &moduleTranslation) {
4167 if (failed(checkImplementationStatus(*op.getOperation())))
4168 return failure();
4169
4170 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4171 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4172
4173 llvm::Value *ifCond = nullptr;
4174 if (Value ifVar = op.getIfExpr())
4175 ifCond = moduleTranslation.lookupValue(ifVar);
4176
4177 llvm::omp::Directive cancelledDirective =
4178 convertCancellationConstructType(op.getCancelDirective());
4179
4180 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4181 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4182
4183 if (failed(handleError(afterIP, *op.getOperation())))
4184 return failure();
4185
4186 builder.restoreIP(afterIP.get());
4187
4188 return success();
4189}
4190
4191static LogicalResult
4192convertOmpCancellationPoint(omp::CancellationPointOp op,
4193 llvm::IRBuilderBase &builder,
4194 LLVM::ModuleTranslation &moduleTranslation) {
4195 if (failed(checkImplementationStatus(*op.getOperation())))
4196 return failure();
4197
4198 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4199 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4200
4201 llvm::omp::Directive cancelledDirective =
4202 convertCancellationConstructType(op.getCancelDirective());
4203
4204 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4205 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4206
4207 if (failed(handleError(afterIP, *op.getOperation())))
4208 return failure();
4209
4210 builder.restoreIP(afterIP.get());
4211
4212 return success();
4213}
4214
4215/// Converts an OpenMP Threadprivate operation into LLVM IR using
4216/// OpenMPIRBuilder.
4217static LogicalResult
4218convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4219 LLVM::ModuleTranslation &moduleTranslation) {
4220 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4221 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4222 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4223
4224 if (failed(checkImplementationStatus(opInst)))
4225 return failure();
4226
4227 Value symAddr = threadprivateOp.getSymAddr();
4228 auto *symOp = symAddr.getDefiningOp();
4229
4230 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4231 symOp = asCast.getOperand().getDefiningOp();
4232
4233 if (!isa<LLVM::AddressOfOp>(symOp))
4234 return opInst.emitError("Addressing symbol not found");
4235 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4236
4237 LLVM::GlobalOp global =
4238 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4239 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4240 llvm::Type *type = globalValue->getValueType();
4241 llvm::TypeSize typeSize =
4242 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4243 type);
4244 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4245 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4246 ompLoc, globalValue, size, global.getSymName() + ".cache");
4247 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4248
4249 return success();
4250}
4251
4252static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4253convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4254 switch (deviceClause) {
4255 case mlir::omp::DeclareTargetDeviceType::host:
4256 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4257 break;
4258 case mlir::omp::DeclareTargetDeviceType::nohost:
4259 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4260 break;
4261 case mlir::omp::DeclareTargetDeviceType::any:
4262 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4263 break;
4264 }
4265 llvm_unreachable("unhandled device clause");
4266}
4267
/// Maps an MLIR declare-target capture clause to the corresponding
/// offload-entries global-variable entry kind.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  case mlir::omp::DeclareTargetCaptureClause::none:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
  }
  llvm_unreachable("unhandled capture clause");
}
4283
/// Resolves \p value to the global operation it addresses, looking through an
/// optional address-space cast and an llvm.mlir.addressof; returns nullptr
/// when the value does not trace back to a global symbol.
static Operation *getGlobalOpFromValue(Value value) {
  Operation *op = value.getDefiningOp();
  // The address may be wrapped in an address-space cast; look through it.
  if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
    op = addrCast->getOperand(0).getDefiningOp();
  // Resolve an addressof to the global it names via the enclosing module.
  if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    return modOp.lookupSymbol(addressOfOp.getGlobalName());
  }
  return nullptr;
}
4294
/// Builds the suffix appended to a declare-target global's name to form the
/// name of its reference pointer ("_decl_tgt_ref_ptr"). For private globals
/// a file-unique id is prepended to the suffix, matching how unique target
/// entry info is derived from the source file.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): assumes the global's location wraps a FileLineColLoc;
    // `loc` would be null otherwise and the callback below would fail --
    // confirm this invariant holds for all callers.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    auto vfs = llvm::vfs::getRealFileSystem();
    os << llvm::format(
        "_%x",
        ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
4316
4317static bool isDeclareTargetLink(Value value) {
4318 if (auto declareTargetGlobal =
4319 dyn_cast_if_present<omp::DeclareTargetInterface>(
4320 getGlobalOpFromValue(value)))
4321 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4322 omp::DeclareTargetCaptureClause::link)
4323 return true;
4324 return false;
4325}
4326
4327static bool isDeclareTargetTo(Value value) {
4328 if (auto declareTargetGlobal =
4329 dyn_cast_if_present<omp::DeclareTargetInterface>(
4330 getGlobalOpFromValue(value)))
4331 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4332 omp::DeclareTargetCaptureClause::to ||
4333 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4334 omp::DeclareTargetCaptureClause::enter)
4335 return true;
4336 return false;
4337}
4338
// Returns the reference pointer generated by the lowering of the declare
// target operation in cases where the link clause is used or the to clause is
// used in USM mode.
static llvm::Value *
getRefPtrIfDeclareTarget(mlir::Value value,
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (auto gOp =
          dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
    if (auto declareTargetGlobal =
            dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
      // In this case, we must utilise the reference pointer generated by
      // the declare target operation, similar to Clang
      if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
           omp::DeclareTargetCaptureClause::link) ||
          (declareTargetGlobal.getDeclareTargetCaptureClause() ==
               omp::DeclareTargetCaptureClause::to &&
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        llvm::SmallString<64> suffix =
            getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);

        // The global may already be the reference pointer itself; if so,
        // look it up by its own name rather than appending the suffix again.
        if (gOp.getSymName().contains(suffix))
          return moduleTranslation.getLLVMModule()->getNamedValue(
              gOp.getSymName());

        return moduleTranslation.getLLVMModule()->getNamedValue(
            (gOp.getSymName().str() + suffix.str()).str());
      }
    }
  }
  return nullptr;
}
4371
4372namespace {
4373// Append customMappers information to existing MapInfosTy
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // Per-entry custom mapper operation, or nullptr when the entry has no
  // user-defined mapper associated with it.
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalesce it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than necessary.
struct MapInfoData : MapInfosTy {
  // True when the entry maps a declare-target global.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True when the entry is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The omp.map.info operation each entry originates from.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The translated LLVM value of the mapped variable (or its pointer).
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  /// NOTE(review): IsAMember and IsAMapping are not appended here -- confirm
  /// callers do not rely on them after an append.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
4411
// Discriminates the kinds of OpenMP target-related operations (see
// getTargetDirectiveEnumTyFromOp); None is used for any other operation.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};
4420
4421static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4422 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4423 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4424 .Case([](omp::TargetEnterDataOp) {
4425 return TargetDirectiveEnumTy::TargetEnterData;
4426 })
4427 .Case([&](omp::TargetExitDataOp) {
4428 return TargetDirectiveEnumTy::TargetExitData;
4429 })
4430 .Case([&](omp::TargetUpdateOp) {
4431 return TargetDirectiveEnumTy::TargetUpdate;
4432 })
4433 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4434 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4435}
4436
4437} // namespace
4438
4439static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4440 DataLayout &dl) {
4441 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4442 arrTy.getElementType()))
4443 return getArrayElementSizeInBits(nestedArrTy, dl);
4444 return dl.getTypeSizeInBits(arrTy.getElementType());
4445}
4446
// This function calculates the size to be offloaded for a specified type, given
// its associated map clause (which can contain bounds information which affects
// the total size), this size is calculated based on the underlying element type
// e.g. given a 1-D array of ints, we will calculate the size from the integer
// type * number of elements in the array. This size can be used in other
// calculations but is ultimately used as an argument to the OpenMP runtimes
// kernel argument structure which is generated through the combinedInfo data
// structures.
// This function is somewhat equivalent to Clang's getExprTypeSize inside of
// CGOpenMPRuntime.cpp.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underlying types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No bounds information: fall back to the static size of the type itself.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
4506
4507// Convert the MLIR map flag set to the runtime map flag set for embedding
4508// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4509// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4510// Certain flags are discarded here such as RefPtee and co.
4511static llvm::omp::OpenMPOffloadMappingFlags
4512convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4513 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4514 return (mlirFlags & flag) == flag;
4515 };
4516 const bool hasExplicitMap =
4517 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4518 omp::ClauseMapFlags::none;
4519
4520 llvm::omp::OpenMPOffloadMappingFlags mapType =
4521 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4522
4523 if (mapTypeToBool(omp::ClauseMapFlags::to))
4524 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4525
4526 if (mapTypeToBool(omp::ClauseMapFlags::from))
4527 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4528
4529 if (mapTypeToBool(omp::ClauseMapFlags::always))
4530 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4531
4532 if (mapTypeToBool(omp::ClauseMapFlags::del))
4533 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4534
4535 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4536 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4537
4538 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4539 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4540
4541 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4542 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4543
4544 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4545 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4546
4547 if (mapTypeToBool(omp::ClauseMapFlags::close))
4548 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4549
4550 if (mapTypeToBool(omp::ClauseMapFlags::present))
4551 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4552
4553 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4554 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4555
4556 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4557 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4558
4559 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4560 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4561 if (!hasExplicitMap)
4562 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4563 }
4564
4565 return mapType;
4566}
4567
4569 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
4570 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
4571 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
4572 ArrayRef<Value> useDevAddrOperands = {},
4573 ArrayRef<Value> hasDevAddrOperands = {}) {
4574 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4575 // Check if this is a member mapping and correctly assign that it is, if
4576 // it is a member of a larger object.
4577 // TODO: Need better handling of members, and distinguishing of members
4578 // that are implicitly allocated on device vs explicitly passed in as
4579 // arguments.
4580 // TODO: May require some further additions to support nested record
4581 // types, i.e. member maps that can have member maps.
4582 for (Value mapValue : mapVars) {
4583 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4584 for (auto member : map.getMembers())
4585 if (member == mapOp)
4586 return true;
4587 }
4588 return false;
4589 };
4590
4591 // Process MapOperands
4592 for (Value mapValue : mapVars) {
4593 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4594 Value offloadPtr =
4595 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4596 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4597 mapData.Pointers.push_back(mapData.OriginalValue.back());
4598
4599 if (llvm::Value *refPtr =
4600 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
4601 mapData.IsDeclareTarget.push_back(true);
4602 mapData.BasePointers.push_back(refPtr);
4603 } else if (isDeclareTargetTo(offloadPtr)) {
4604 mapData.IsDeclareTarget.push_back(true);
4605 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4606 } else { // regular mapped variable
4607 mapData.IsDeclareTarget.push_back(false);
4608 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4609 }
4610
4611 mapData.BaseType.push_back(
4612 moduleTranslation.convertType(mapOp.getVarType()));
4613 mapData.Sizes.push_back(
4614 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4615 mapData.BaseType.back(), builder, moduleTranslation));
4616 mapData.MapClause.push_back(mapOp.getOperation());
4617 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
4618 mapData.Names.push_back(LLVM::createMappingInformation(
4619 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4620 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4621 if (mapOp.getMapperId())
4622 mapData.Mappers.push_back(
4624 mapOp, mapOp.getMapperIdAttr()));
4625 else
4626 mapData.Mappers.push_back(nullptr);
4627 mapData.IsAMapping.push_back(true);
4628 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4629 }
4630
4631 auto findMapInfo = [&mapData](llvm::Value *val,
4632 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4633 unsigned index = 0;
4634 bool found = false;
4635 for (llvm::Value *basePtr : mapData.OriginalValue) {
4636 if (basePtr == val && mapData.IsAMapping[index]) {
4637 found = true;
4638 mapData.Types[index] |=
4639 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4640 mapData.DevicePointers[index] = devInfoTy;
4641 }
4642 index++;
4643 }
4644 return found;
4645 };
4646
4647 // Process useDevPtr(Addr)Operands
4648 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4649 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4650 for (Value mapValue : useDevOperands) {
4651 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4652 Value offloadPtr =
4653 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4654 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4655
4656 // Check if map info is already present for this entry.
4657 if (!findMapInfo(origValue, devInfoTy)) {
4658 mapData.OriginalValue.push_back(origValue);
4659 mapData.Pointers.push_back(mapData.OriginalValue.back());
4660 mapData.IsDeclareTarget.push_back(false);
4661 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4662 mapData.BaseType.push_back(
4663 moduleTranslation.convertType(mapOp.getVarType()));
4664 mapData.Sizes.push_back(builder.getInt64(0));
4665 mapData.MapClause.push_back(mapOp.getOperation());
4666 mapData.Types.push_back(
4667 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4668 mapData.Names.push_back(LLVM::createMappingInformation(
4669 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4670 mapData.DevicePointers.push_back(devInfoTy);
4671 mapData.Mappers.push_back(nullptr);
4672 mapData.IsAMapping.push_back(false);
4673 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4674 }
4675 }
4676 };
4677
4678 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4679 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4680
4681 for (Value mapValue : hasDevAddrOperands) {
4682 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4683 Value offloadPtr =
4684 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4685 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4686 auto mapType = convertClauseMapFlags(mapOp.getMapType());
4687 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4688 bool isDevicePtr =
4689 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4690 omp::ClauseMapFlags::none;
4691
4692 mapData.OriginalValue.push_back(origValue);
4693 mapData.BasePointers.push_back(origValue);
4694 mapData.Pointers.push_back(origValue);
4695 mapData.IsDeclareTarget.push_back(false);
4696 mapData.BaseType.push_back(
4697 moduleTranslation.convertType(mapOp.getVarType()));
4698 mapData.Sizes.push_back(
4699 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4700 mapData.MapClause.push_back(mapOp.getOperation());
4701 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4702 // Descriptors are mapped with the ALWAYS flag, since they can get
4703 // rematerialized, so the address of the decriptor for a given object
4704 // may change from one place to another.
4705 mapData.Types.push_back(mapType);
4706 // Technically it's possible for a non-descriptor mapping to have
4707 // both has-device-addr and ALWAYS, so lookup the mapper in case it
4708 // exists.
4709 if (mapOp.getMapperId()) {
4710 mapData.Mappers.push_back(
4712 mapOp, mapOp.getMapperIdAttr()));
4713 } else {
4714 mapData.Mappers.push_back(nullptr);
4715 }
4716 } else {
4717 // For is_device_ptr we need the map type to propagate so the runtime
4718 // can materialize the device-side copy of the pointer container.
4719 mapData.Types.push_back(
4720 isDevicePtr ? mapType
4721 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4722 mapData.Mappers.push_back(nullptr);
4723 }
4724 mapData.Names.push_back(LLVM::createMappingInformation(
4725 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4726 mapData.DevicePointers.push_back(
4727 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4728 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4729 mapData.IsAMapping.push_back(false);
4730 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4731 }
4732}
4733
4734static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4735 auto *res = llvm::find(mapData.MapClause, memberOp);
4736 assert(res != mapData.MapClause.end() &&
4737 "MapInfoOp for member not found in MapData, cannot return index");
4738 return std::distance(mapData.MapClause.begin(), res);
4739}
4740
// Sorts \p indices (positions into \p mapInfo's members-index attribute) so
// that lexicographically smaller member paths come first and parents precede
// their children, then drops children occluded by a mapped parent.
static void sortMapIndices(llvm::SmallVector<size_t> &indices,
                           omp::MapInfoOp mapInfo) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  llvm::SmallVector<size_t> occludedChildren;
  llvm::sort(
      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
        // Bail early if we are asked to look at the same index. If we do not
        // bail early, we can end up mistakenly adding indices to
        // occludedChildren. This can occur with some types of libc++ hardening.
        if (a == b)
          return false;

        auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
        auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);

        // Compare the two member paths element by element.
        for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
          int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
          int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();

          if (aIndex == bIndex)
            continue;

          if (aIndex < bIndex)
            return true;

          if (aIndex > bIndex)
            return false;
        }

        // Iterated up until the end of the smallest member and
        // they were found to be equal up to that point, so select
        // the member with the lowest index count, so the "parent"
        bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
        if (memberAParent)
          occludedChildren.push_back(b);
        else
          occludedChildren.push_back(a);
        return memberAParent;
      });

  // We remove children from the index list that are overshadowed by
  // a parent, this prevents us retrieving these as the first or last
  // element when the parent is the correct element in these cases.
  for (auto v : occludedChildren)
    indices.erase(std::remove(indices.begin(), indices.end(), v),
                  indices.end());
}
4788
4789static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4790 bool first) {
4791 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4792 // Only 1 member has been mapped, we can return it.
4793 if (indexAttr.size() == 1)
4794 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4795 llvm::SmallVector<size_t> indices(indexAttr.size());
4796 std::iota(indices.begin(), indices.end(), 0);
4797 sortMapIndices(indices, mapInfo);
4798 return llvm::cast<omp::MapInfoOp>(
4799 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4800 .getDefiningOp());
4801}
4802
4803/// This function calculates the array/pointer offset for map data provided
4804/// with bounds operations, e.g. when provided something like the following:
4805///
4806/// Fortran
4807/// map(tofrom: array(2:5, 3:2))
4808///
4809/// We must calculate the initial pointer offset to pass across, this function
4810/// performs this using bounds.
4811///
4812/// TODO/WARNING: This only supports Fortran's column major indexing currently
4813/// as is noted in the note below and comments in the function, we must extend
4814/// this function when we add a C++ frontend.
4815/// NOTE: which while specified in row-major order it currently needs to be
4816/// flipped for Fortran's column order array allocation and access (as
4817/// opposed to C++'s row-major, hence the backwards processing where order is
4818/// important). This is likely important to keep in mind for the future when
4819/// we incorporate a C++ frontend, both frontends will need to agree on the
4820/// ordering of generated bounds operations (one may have to flip them) to
4821/// make the below lowering frontend agnostic. The offload size
4822/// calcualtion may also have to be adjusted for C++.
4823static std::vector<llvm::Value *>
4825 llvm::IRBuilderBase &builder, bool isArrayTy,
4826 OperandRange bounds) {
4827 std::vector<llvm::Value *> idx;
4828 // There's no bounds to calculate an offset from, we can safely
4829 // ignore and return no indices.
4830 if (bounds.empty())
4831 return idx;
4832
4833 // If we have an array type, then we have its type so can treat it as a
4834 // normal GEP instruction where the bounds operations are simply indexes
4835 // into the array. We currently do reverse order of the bounds, which
4836 // I believe leans more towards Fortran's column-major in memory.
4837 if (isArrayTy) {
4838 idx.push_back(builder.getInt64(0));
4839 for (int i = bounds.size() - 1; i >= 0; --i) {
4840 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4841 bounds[i].getDefiningOp())) {
4842 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4843 }
4844 }
4845 } else {
4846 // If we do not have an array type, but we have bounds, then we're dealing
4847 // with a pointer that's being treated like an array and we have the
4848 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
4849 // address (pointer pointing to the actual data) so we must caclulate the
4850 // offset using a single index which the following loop attempts to
4851 // compute using the standard column-major algorithm e.g for a 3D array:
4852 //
4853 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
4854 //
4855 // It is of note that it's doing column-major rather than row-major at the
4856 // moment, but having a way for the frontend to indicate which major format
4857 // to use or standardizing/canonicalizing the order of the bounds to compute
4858 // the offset may be useful in the future when there's other frontends with
4859 // different formats.
4860 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4861 for (int i = bounds.size() - 1; i >= 0; --i) {
4862 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4863 bounds[i].getDefiningOp())) {
4864 if (i == ((int)bounds.size() - 1))
4865 idx.emplace_back(
4866 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4867 else
4868 idx.back() = builder.CreateAdd(
4869 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4870 boundOp.getExtent())),
4871 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4872 }
4873 }
4874 }
4875
4876 return idx;
4877}
4878
// Converts an ArrayAttr of IntegerAttrs into plain int64_t values.
// NOTE(review): this APPENDS to 'ints' via back_inserter — it does not
// overwrite existing elements, so callers should pass an empty (or
// deliberately pre-seeded) vector; confirm all call sites expect append
// semantics.
4880 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4881 return cast<IntegerAttr>(value).getInt();
4882 });
4883 }
4884
4885// Gathers members that are overlapping in the parent, excluding members that
4886// themselves overlap, keeping the top-most (closest to parents level) map.
4887static void
4889 omp::MapInfoOp parentOp) {
4890 // No members mapped, no overlaps.
4891 if (parentOp.getMembers().empty())
4892 return;
4893
4894 // Single member, we can insert and return early.
4895 if (parentOp.getMembers().size() == 1) {
4896 overlapMapDataIdxs.push_back(0);
4897 return;
4898 }
4899
4900 // 1) collect list of top-level overlapping members from MemberOp
4902 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4903 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4904 memberByIndex.push_back(
4905 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4906
4907 // Sort the smallest first (higher up the parent -> member chain), so that
4908 // when we remove members, we remove as much as we can in the initial
4909 // iterations, shortening the number of passes required.
4910 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4911 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
4912
4913 // Remove elements from the vector if there is a parent element that
4914 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
4915 // [0,2].. etc.
4917 for (auto v : memberByIndex) {
4918 llvm::SmallVector<int64_t> vArr(v.second.size());
4919 getAsIntegers(v.second, vArr);
4920 skipList.push_back(
4921 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
4922 if (v == x)
4923 return false;
4924 llvm::SmallVector<int64_t> xArr(x.second.size());
4925 getAsIntegers(x.second, xArr);
4926 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4927 xArr.size() >= vArr.size();
4928 }));
4929 }
4930
4931 // Collect the indices, as we need the base pointer etc. from the MapData
4932 // structure which is primarily accessible via index at the moment.
4933 for (auto v : memberByIndex)
4934 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4935 overlapMapDataIdxs.push_back(v.first);
4936}
4937
4938// The intent is to verify if the mapped data being passed is a
4939// pointer -> pointee that requires special handling in certain cases,
4940// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4941//
4942// There may be a better way to verify this, but unfortunately with
4943// opaque pointers we lose the ability to easily check if something is
4944// a pointer whilst maintaining access to the underlying type.
4945static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4946 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4947 if (mapOp.getVarPtrPtr())
4948 return true;
4949
4950 // If the map data is declare target with a link clause, then it's represented
4951 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4952 // no relation to pointers.
4953 if (isDeclareTargetLink(mapOp.getVarPtr()))
4954 return true;
4955
4956 return false;
4957}
4958
4959 // This creates two insertions into the MapInfosTy data structure for the
4960 // "parent" of a set of members, (usually a container e.g.
4961 // class/structure/derived type) when subsequent members have also been
4962 // explicitly mapped on the same map clause. Certain types, such as Fortran
4963 // descriptors are mapped like this as well, however, the members are
4964 // implicit as far as a user is concerned, but we must explicitly map them
4965 // internally.
4966 //
4967 // This function also returns the memberOfFlag for this particular parent,
4968 // which is utilised in subsequent member mappings (by modifying their map
4969 // type with it) to indicate that a member is part of this parent and should be
4970 // treated by the runtime as such. Important to achieve the correct mapping.
4971 //
4972 // This function borrows a lot from Clang's emitCombinedEntry function
4973 // inside of CGOpenMPRuntime.cpp
4974 static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
4975 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4976 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4977 MapInfoData &mapData, uint64_t mapDataIndex,
4978 TargetDirectiveEnumTy targetDirective) {
4979 assert(!ompBuilder.Config.isTargetDevice() &&
4980 "function only supported for host device codegen");
4981
4982 auto parentClause =
4983 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4984
4985 auto *parentMapper = mapData.Mappers[mapDataIndex];
4986
4987 // Map the first segment of the parent. If a user-defined mapper is attached,
4988 // include the parent's to/from-style bits (and common modifiers) in this
4989 // base entry so the mapper receives correct copy semantics via its 'type'
4990 // parameter. Also keep TARGET_PARAM when required for kernel arguments.
4991 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4992 (targetDirective == TargetDirectiveEnumTy::Target &&
4993 !mapData.IsDeclareTarget[mapDataIndex])
4994 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4995 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4996
4997 if (parentMapper) {
4998 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4999 // Preserve relevant map-type bits from the parent clause. These include
5000 // the copy direction (TO/FROM), as well as commonly used modifiers that
5001 // should be visible to the mapper for correct behaviour.
5002 mapFlags parentFlags = mapData.Types[mapDataIndex];
5003 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
5004 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
5005 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
5006 baseFlag |= (parentFlags & preserve);
5007 }
5008
// First insertion: the base entry describing the parent object itself.
5009 combinedInfo.Types.emplace_back(baseFlag);
5010 combinedInfo.DevicePointers.emplace_back(
5011 mapData.DevicePointers[mapDataIndex]);
5012 // Only attach the mapper to the base entry when we are mapping the whole
5013 // parent. Combined/segment entries must not carry a mapper; otherwise the
5014 // mapper can be invoked with a partial size, which is undefined behaviour.
5015 combinedInfo.Mappers.emplace_back(
5016 parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
5017 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5018 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5019 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
5020
5021 // Calculate size of the parent object being mapped based on the
5022 // addresses at runtime, highAddr - lowAddr = size. This of course
5023 // doesn't factor in allocated data like pointers, hence the further
5024 // processing of members specified by users, or in the case of
5025 // Fortran pointers and allocatables, the mapping of the pointed to
5026 // data by the descriptor (which itself, is a structure containing
5027 // runtime information on the dynamically allocated data).
5028 llvm::Value *lowAddr, *highAddr;
5029 if (!parentClause.getPartialMap()) {
5030 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5031 builder.getPtrTy());
5032 highAddr = builder.CreatePointerCast(
5033 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5034 mapData.Pointers[mapDataIndex], 1),
5035 builder.getPtrTy());
5036 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5037 } else {
// Partial map: the mapped range spans from the lowest to the highest
// explicitly mapped member rather than the whole parent object.
5038 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5039 int firstMemberIdx = getMapDataMemberIdx(
5040 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
5041 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5042 builder.getPtrTy());
5043 int lastMemberIdx = getMapDataMemberIdx(
5044 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
5045 highAddr = builder.CreatePointerCast(
5046 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5047 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5048 builder.getPtrTy());
5049 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
5050 }
5051
5052 llvm::Value *size = builder.CreateIntCast(
5053 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5054 builder.getInt64Ty(),
5055 /*isSigned=*/false);
5056 combinedInfo.Sizes.push_back(size);
5057
// The MEMBER_OF flag is derived from the position of the base entry emitted
// above; members emitted later reference the parent through it.
5058 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5059 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5060
5061 // This creates the initial MEMBER_OF mapping that consists of
5062 // the parent/top level container (same as above effectively, except
5063 // with a fixed initial compile time size and separate maptype which
5064 // indicates the true map type (tofrom etc.). This parent mapping is
5065 // only relevant if the structure in its totality is being mapped,
5066 // otherwise the above suffices.
5067 if (!parentClause.getPartialMap()) {
5068 // TODO: This will need to be expanded to include the whole host of logic
5069 // for the map flags that Clang currently supports (e.g. it should do some
5070 // further case specific flag modifications). For the moment, it handles
5071 // what we support as expected.
5072 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5073 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5074 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5075 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5076 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5077
5078 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
// Simple case: one combined entry covering the whole parent.
5079 combinedInfo.Types.emplace_back(mapFlag)
5080 combinedInfo.DevicePointers.emplace_back(
5081 mapData.DevicePointers[mapDataIndex]);
5082 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5083 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5084 combinedInfo.BasePointers.emplace_back(
5085 mapData.BasePointers[mapDataIndex]);
5086 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5087 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5088 combinedInfo.Mappers.emplace_back(nullptr);
5089 } else {
5090 llvm::SmallVector<size_t> overlapIdxs;
5091 // Find all of the members that "overlap", i.e. occlude other members that
5092 // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
5093 // [0,2], but not [1,0].
5094 getOverlappedMembers(overlapIdxs, parentClause);
5095 // We need to make sure the overlapped members are sorted in order of
5096 // lowest address to highest address.
5097 sortMapIndices(overlapIdxs, parentClause);
5098
5099 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5100 builder.getPtrTy());
5101 highAddr = builder.CreatePointerCast(
5102 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5103 mapData.Pointers[mapDataIndex], 1),
5104 builder.getPtrTy());
5105
5106 // TODO: We may want to skip arrays/array sections in this as Clang does.
5107 // It appears to be an optimisation rather than a necessity though,
5108 // but this requires further investigation. However, we would have to make
5109 // sure to not exclude maps with bounds that ARE pointers, as these are
5110 // processed as separate components, i.e. pointer + data.
// Emit one "gap" segment per overlapped member: each entry covers the
// bytes from the running lowAddr up to the start of that member, then
// lowAddr is advanced past the member.
5111 for (auto v : overlapIdxs) {
5112 auto mapDataOverlapIdx = getMapDataMemberIdx(
5113 mapData,
5114 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5115 combinedInfo.Types.emplace_back(mapFlag);
5116 combinedInfo.DevicePointers.emplace_back(
5117 mapData.DevicePointers[mapDataOverlapIdx]);
5118 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5119 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5120 combinedInfo.BasePointers.emplace_back(
5121 mapData.BasePointers[mapDataIndex]);
5122 combinedInfo.Mappers.emplace_back(nullptr);
5123 combinedInfo.Pointers.emplace_back(lowAddr);
5124 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5125 builder.CreatePtrDiff(builder.getInt8Ty(),
5126 mapData.OriginalValue[mapDataOverlapIdx],
5127 lowAddr),
5128 builder.getInt64Ty(), /*isSigned=*/true));
5129 lowAddr = builder.CreateConstGEP1_32(
5130 checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
5131 mapData.MapClause[mapDataOverlapIdx]))
5132 ? builder.getPtrTy()
5133 : mapData.BaseType[mapDataOverlapIdx],
5134 mapData.BasePointers[mapDataOverlapIdx], 1);
5135 }
5136
// Final segment: from the end of the last overlapped member to the end
// of the parent object.
5137 combinedInfo.Types.emplace_back(mapFlag);
5138 combinedInfo.DevicePointers.emplace_back(
5139 mapData.DevicePointers[mapDataIndex]);
5140 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5141 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5142 combinedInfo.BasePointers.emplace_back(
5143 mapData.BasePointers[mapDataIndex]);
5144 combinedInfo.Mappers.emplace_back(nullptr);
5145 combinedInfo.Pointers.emplace_back(lowAddr);
5146 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5147 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5148 builder.getInt64Ty(), true));
5149 }
5150 }
5151 return memberOfFlag;
5152 }
5153
5154 // This function is intended to add explicit mappings of members
// of a parent object that has already been mapped (see
// mapParentWithMembers): each member gets its own entry carrying the
// parent's MEMBER_OF flag, and pointer members additionally get a
// pointer-sized binding entry.
5156 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5157 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5158 MapInfoData &mapData, uint64_t mapDataIndex,
5159 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5160 TargetDirectiveEnumTy targetDirective) {
5161 assert(!ompBuilder.Config.isTargetDevice() &&
5162 "function only supported for host device codegen");
5163
5164 auto parentClause =
5165 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5166
5167 for (auto mappedMembers : parentClause.getMembers()) {
5168 auto memberClause =
5169 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp())
5170 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5171
5172 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5173
5174 // If we're currently mapping a pointer to a block of data, we must
5175 // initially map the pointer, and then attach/bind the data with a
5176 // subsequent map to the pointer. This segment of code generates the
5177 // pointer mapping, which can in certain cases be optimised out as Clang
5178 // currently does in its lowering. However, for the moment we do not do so,
5179 // in part as we currently have substantially less information on the data
5180 // being mapped at this stage.
5181 if (checkIfPointerMap(memberClause)) {
5182 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5183 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5184 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5185 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5186 combinedInfo.Types.emplace_back(mapFlag);
5187 combinedInfo.DevicePointers.emplace_back(
5188 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5189 combinedInfo.Mappers.emplace_back(nullptr);
5190 combinedInfo.Names.emplace_back(
5191 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5192 combinedInfo.BasePointers.emplace_back(
5193 mapData.BasePointers[mapDataIndex]);
5194 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
// The pointer entry's size is the host pointer size, not the pointee size.
5195 combinedInfo.Sizes.emplace_back(builder.getInt64(
5196 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5197 }
5198
5199 // Same MemberOfFlag to indicate its link with parent and other members
5200 // of.
5201 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5202 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5203 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5204 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5205 bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5206 ? parentClause.getVarPtr()
5207 : parentClause.getVarPtrPtr());
5208 if (checkIfPointerMap(memberClause) &&
5209 (!isDeclTargetTo ||
5210 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5211 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5212 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5213 }
5214
5215 combinedInfo.Types.emplace_back(mapFlag);
5216 combinedInfo.DevicePointers.emplace_back(
5217 mapData.DevicePointers[memberDataIdx]);
5218 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5219 combinedInfo.Names.emplace_back(
5220 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
// Pointer members use their own base pointer; plain members are based on
// the parent so the runtime can relate them to the parent entry.
5221 uint64_t basePointerIndex =
5222 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5223 combinedInfo.BasePointers.emplace_back(
5224 mapData.BasePointers[basePointerIndex]);
5225 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5226
// For pointer members, map zero bytes when the pointer is null at runtime.
5227 llvm::Value *size = mapData.Sizes[memberDataIdx];
5228 if (checkIfPointerMap(memberClause)) {
5229 size = builder.CreateSelect(
5230 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5231 builder.getInt64(0), size);
5232 }
5233
5234 combinedInfo.Sizes.emplace_back(size);
5235 }
5236 }
5237
5238static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5239 MapInfosTy &combinedInfo,
5240 TargetDirectiveEnumTy targetDirective,
5241 int mapDataParentIdx = -1) {
5242 // Declare Target Mappings are excluded from being marked as
5243 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5244 // marked with OMP_MAP_PTR_AND_OBJ instead.
5245 auto mapFlag = mapData.Types[mapDataIdx];
5246 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5247
5248 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5249 if (isPtrTy)
5250 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5251
5252 if (targetDirective == TargetDirectiveEnumTy::Target &&
5253 !mapData.IsDeclareTarget[mapDataIdx])
5254 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5255
5256 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5257 !isPtrTy)
5258 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5259
5260 // if we're provided a mapDataParentIdx, then the data being mapped is
5261 // part of a larger object (in a parent <-> member mapping) and in this
5262 // case our BasePointer should be the parent.
5263 if (mapDataParentIdx >= 0)
5264 combinedInfo.BasePointers.emplace_back(
5265 mapData.BasePointers[mapDataParentIdx]);
5266 else
5267 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5268
5269 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5270 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5271 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5272 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5273 combinedInfo.Types.emplace_back(mapFlag);
5274 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5275}
5276
5278 llvm::IRBuilderBase &builder,
5279 llvm::OpenMPIRBuilder &ompBuilder,
5280 DataLayout &dl, MapInfosTy &combinedInfo,
5281 MapInfoData &mapData, uint64_t mapDataIndex,
5282 TargetDirectiveEnumTy targetDirective) {
// Dispatches a map clause that carries member mappings: either lowers it to
// one individual entry (single-member partial map) or emits the full
// parent + members entry set.
5283 assert(!ompBuilder.Config.isTargetDevice() &&
5284 "function only supported for host device codegen");
5285
5286 auto parentClause =
5287 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5288
5289 // If we have a partial map (no parent referenced in the map clauses of the
5290 // directive, only members) and only a single member, we do not need to bind
5291 // the map of the member to the parent, we can pass the member separately.
5292 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5293 auto memberClause = llvm::cast<omp::MapInfoOp>(
5294 parentClause.getMembers()[0].getDefiningOp());
5295 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5296 // Note: Clang treats arrays with explicit bounds that fall into this
5297 // category as a parent with map case, however, it seems this isn't a
5298 // requirement, and processing them as an individual map is fine. So,
5299 // we will handle them as individual maps for the moment, as it's
5300 // difficult for us to check this as we always require bounds to be
5301 // specified currently and it's also marginally more optimal (single
5302 // map rather than two). The difference may come from the fact that
5303 // Clang maps array without bounds as pointers (which we do not
5304 // currently do), whereas we treat them as arrays in all cases
5305 // currently.
5306 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
5307 mapDataIndex);
5308 return;
5309 }
5310
// General case: emit the parent entries first (returning the MEMBER_OF
// flag), then one entry per explicitly mapped member.
5311 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5312 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
5313 combinedInfo, mapData, mapDataIndex,
5314 targetDirective);
5315 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
5316 combinedInfo, mapData, mapDataIndex,
5317 memberOfParentFlag, targetDirective);
5318 }
5319
5320 // This is a variation on Clang's GenerateOpenMPCapturedVars, which
5321 // generates different operation (e.g. load/store) combinations for
5322 // arguments to the kernel, based on map capture kinds which are then
5323 // utilised in the combinedInfo in place of the original Map value.
5324 static void
5325 createAlteredByCaptureMap(MapInfoData &mapData,
5326 LLVM::ModuleTranslation &moduleTranslation,
5327 llvm::IRBuilderBase &builder) {
5328 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5329 "function only supported for host device codegen");
5330 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5331 // if it's declare target, skip it, it's handled separately.
5332 if (!mapData.IsDeclareTarget[i]) {
5333 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5334 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5335 bool isPtrTy = checkIfPointerMap(mapOp);
5336
5337 // Currently handles array sectioning lowerbound case, but more
5338 // logic may be required in the future. Clang invokes EmitLValue,
5339 // which has specialised logic for special Clang types such as user
5340 // defines, so it is possible we will have to extend this for
5341 // structures or other complex types. As the general idea is that this
5342 // function mimics some of the logic from Clang that we require for
5343 // kernel argument passing from host -> device.
5344 switch (captureKind) {
5345 case omp::VariableCaptureKind::ByRef: {
5346 llvm::Value *newV = mapData.Pointers[i];
// Apply any bounds-derived offset so the mapped pointer starts at the
// requested section rather than the base of the object.
5347 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
5348 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5349 mapOp.getBounds());
5350 if (isPtrTy)
5351 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5352
5353 if (!offsetIdx.empty())
5354 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5355 "array_offset");
5356 mapData.Pointers[i] = newV;
5357 } break;
5358 case omp::VariableCaptureKind::ByCopy: {
5359 llvm::Type *type = mapData.BaseType[i];
5360 llvm::Value *newV;
5361 if (mapData.Pointers[i]->getType()->isPointerTy())
5362 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5363 else
5364 newV = mapData.Pointers[i];
5365
// Non-pointer by-copy values are spilled to a temporary alloca (placed
// at the function's alloca insertion point) so they can be passed by
// address like every other kernel argument.
5366 if (!isPtrTy) {
5367 auto curInsert = builder.saveIP();
5368 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5369 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
5370 auto *memTempAlloc =
5371 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
5372 builder.SetCurrentDebugLocation(DbgLoc);
5373 builder.restoreIP(curInsert);
5374
5375 builder.CreateStore(newV, memTempAlloc);
5376 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5377 }
5378
5379 mapData.Pointers[i] = newV;
5380 mapData.BasePointers[i] = newV;
5381 } break;
5382 case omp::VariableCaptureKind::This:
5383 case omp::VariableCaptureKind::VLAType:
// Unsupported capture kinds: report an error on the originating op.
5384 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
5385 break;
5386 }
5387 }
5388 }
5389 }
5390
5391// Generate all map related information and fill the combinedInfo.
5392static void genMapInfos(llvm::IRBuilderBase &builder,
5393 LLVM::ModuleTranslation &moduleTranslation,
5394 DataLayout &dl, MapInfosTy &combinedInfo,
5395 MapInfoData &mapData,
5396 TargetDirectiveEnumTy targetDirective) {
5397 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5398 "function only supported for host device codegen");
5399 // We wish to modify some of the methods in which arguments are
5400 // passed based on their capture type by the target region, this can
5401 // involve generating new loads and stores, which changes the
5402 // MLIR value to LLVM value mapping, however, we only wish to do this
5403 // locally for the current function/target and also avoid altering
5404 // ModuleTranslation, so we remap the base pointer or pointer stored
5405 // in the map infos corresponding MapInfoData, which is later accessed
5406 // by genMapInfos and createTarget to help generate the kernel and
5407 // kernel arg structure. It primarily becomes relevant in cases like
5408 // bycopy, or byref range'd arrays. In the default case, we simply
5409 // pass thee pointer byref as both basePointer and pointer.
5410 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5411
5412 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5413
5414 // We operate under the assumption that all vectors that are
5415 // required in MapInfoData are of equal lengths (either filled with
5416 // default constructed data or appropiate information) so we can
5417 // utilise the size from any component of MapInfoData, if we can't
5418 // something is missing from the initial MapInfoData construction.
5419 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5420 // NOTE/TODO: We currently do not support arbitrary depth record
5421 // type mapping.
5422 if (mapData.IsAMember[i])
5423 continue;
5424
5425 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5426 if (!mapInfoOp.getMembers().empty()) {
5427 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5428 combinedInfo, mapData, i, targetDirective);
5429 continue;
5430 }
5431
5432 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5433 }
5434}
5435
5436static llvm::Expected<llvm::Function *>
5437emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
5438 LLVM::ModuleTranslation &moduleTranslation,
5439 llvm::StringRef mapperFuncName,
5440 TargetDirectiveEnumTy targetDirective);
5441
5442static llvm::Expected<llvm::Function *>
5443getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5444 LLVM::ModuleTranslation &moduleTranslation,
5445 TargetDirectiveEnumTy targetDirective) {
5446 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5447 "function only supported for host device codegen");
5448 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5449 std::string mapperFuncName =
5450 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5451 {"omp_mapper", declMapperOp.getSymName()});
5452
5453 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5454 return lookupFunc;
5455
5456 // Recursive types can cause re-entrant mapper emission. The mapper function
5457 // is created by OpenMPIRBuilder before the callbacks run, so it may already
5458 // exist in the LLVM module even though it is not yet registered in the
5459 // ModuleTranslation mapping table. Reuse and register it to break the
5460 // recursion.
5461 if (llvm::Function *existingFunc =
5462 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
5463 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
5464 return existingFunc;
5465 }
5466
5467 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5468 mapperFuncName, targetDirective);
5469}
5470
/// Emit the LLVM function that implements an `omp.declare_mapper` region.
/// The heavy lifting is delegated to OpenMPIRBuilder::emitUserDefinedMapper;
/// this function supplies the callbacks that translate the mapper's MLIR
/// region and map-info operands into the builder's map-info structures.
///
/// \param op              The omp::DeclareMapperOp to emit (passed as a
///                        generic Operation*).
/// \param mapperFuncName  Pre-computed, platform-mangled function name.
/// \returns the emitted llvm::Function, or an error if region translation
///          failed.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's declared symbol to the pointer PHI supplied by the
    // OpenMPIRBuilder before translating the region body.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Nested mappers (a mapped member may itself use a declare-mapper) are
  // resolved lazily per map-info entry.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Re-entrant emission (recursive types) may already have registered the
  // function while the callbacks above were running; keep the mapping table
  // consistent either way.
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
5532
5533static LogicalResult
5534convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
5535 LLVM::ModuleTranslation &moduleTranslation) {
5536 llvm::Value *ifCond = nullptr;
5537 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5538 SmallVector<Value> mapVars;
5539 SmallVector<Value> useDevicePtrVars;
5540 SmallVector<Value> useDeviceAddrVars;
5541 llvm::omp::RuntimeFunction RTLFn;
5542 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
5543 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5544
5545 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5546 llvm::OpenMPIRBuilder::TargetDataInfo info(
5547 /*RequiresDevicePointerInfo=*/true,
5548 /*SeparateBeginEndCalls=*/true);
5549 assert(!ompBuilder->Config.isTargetDevice() &&
5550 "target data/enter/exit/update are host ops");
5551 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5552
5553 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
5554 llvm::Value *v = moduleTranslation.lookupValue(dev);
5555 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
5556 };
5557
5558 LogicalResult result =
5560 .Case([&](omp::TargetDataOp dataOp) {
5561 if (failed(checkImplementationStatus(*dataOp)))
5562 return failure();
5563
5564 if (auto ifVar = dataOp.getIfExpr())
5565 ifCond = moduleTranslation.lookupValue(ifVar);
5566
5567 if (mlir::Value devId = dataOp.getDevice())
5568 deviceID = getDeviceID(devId);
5569
5570 mapVars = dataOp.getMapVars();
5571 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5572 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5573 return success();
5574 })
5575 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5576 if (failed(checkImplementationStatus(*enterDataOp)))
5577 return failure();
5578
5579 if (auto ifVar = enterDataOp.getIfExpr())
5580 ifCond = moduleTranslation.lookupValue(ifVar);
5581
5582 if (mlir::Value devId = enterDataOp.getDevice())
5583 deviceID = getDeviceID(devId);
5584
5585 RTLFn =
5586 enterDataOp.getNowait()
5587 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5588 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5589 mapVars = enterDataOp.getMapVars();
5590 info.HasNoWait = enterDataOp.getNowait();
5591 return success();
5592 })
5593 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5594 if (failed(checkImplementationStatus(*exitDataOp)))
5595 return failure();
5596
5597 if (auto ifVar = exitDataOp.getIfExpr())
5598 ifCond = moduleTranslation.lookupValue(ifVar);
5599
5600 if (mlir::Value devId = exitDataOp.getDevice())
5601 deviceID = getDeviceID(devId);
5602
5603 RTLFn = exitDataOp.getNowait()
5604 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5605 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5606 mapVars = exitDataOp.getMapVars();
5607 info.HasNoWait = exitDataOp.getNowait();
5608 return success();
5609 })
5610 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5611 if (failed(checkImplementationStatus(*updateDataOp)))
5612 return failure();
5613
5614 if (auto ifVar = updateDataOp.getIfExpr())
5615 ifCond = moduleTranslation.lookupValue(ifVar);
5616
5617 if (mlir::Value devId = updateDataOp.getDevice())
5618 deviceID = getDeviceID(devId);
5619
5620 RTLFn =
5621 updateDataOp.getNowait()
5622 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5623 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5624 mapVars = updateDataOp.getMapVars();
5625 info.HasNoWait = updateDataOp.getNowait();
5626 return success();
5627 })
5628 .DefaultUnreachable("unexpected operation");
5629
5630 if (failed(result))
5631 return failure();
5632 // Pretend we have IF(false) if we're not doing offload.
5633 if (!isOffloadEntry)
5634 ifCond = builder.getFalse();
5635
5636 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5637 MapInfoData mapData;
5638 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
5639 builder, useDevicePtrVars, useDeviceAddrVars);
5640
5641 // Fill up the arrays with all the mapped variables.
5642 MapInfosTy combinedInfo;
5643 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5644 builder.restoreIP(codeGenIP);
5645 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5646 targetDirective);
5647 return combinedInfo;
5648 };
5649
5650 // Define a lambda to apply mappings between use_device_addr and
5651 // use_device_ptr base pointers, and their associated block arguments.
5652 auto mapUseDevice =
5653 [&moduleTranslation](
5654 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5656 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
5657 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
5658 for (auto [arg, useDevVar] :
5659 llvm::zip_equal(blockArgs, useDeviceVars)) {
5660
5661 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5662 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5663 : mapInfoOp.getVarPtr();
5664 };
5665
5666 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5667 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5668 mapInfoData.MapClause, mapInfoData.DevicePointers,
5669 mapInfoData.BasePointers)) {
5670 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5671 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5672 devicePointer != type)
5673 continue;
5674
5675 if (llvm::Value *devPtrInfoMap =
5676 mapper ? mapper(basePointer) : basePointer) {
5677 moduleTranslation.mapValue(arg, devPtrInfoMap);
5678 break;
5679 }
5680 }
5681 }
5682 };
5683
5684 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5685 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5686 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5687 // We must always restoreIP regardless of doing anything the caller
5688 // does not restore it, leading to incorrect (no) branch generation.
5689 builder.restoreIP(codeGenIP);
5690 assert(isa<omp::TargetDataOp>(op) &&
5691 "BodyGen requested for non TargetDataOp");
5692 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5693 Region &region = cast<omp::TargetDataOp>(op).getRegion();
5694 switch (bodyGenType) {
5695 case BodyGenTy::Priv:
5696 // Check if any device ptr/addr info is available
5697 if (!info.DevicePtrInfoMap.empty()) {
5698 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5699 blockArgIface.getUseDeviceAddrBlockArgs(),
5700 useDeviceAddrVars, mapData,
5701 [&](llvm::Value *basePointer) -> llvm::Value * {
5702 if (!info.DevicePtrInfoMap[basePointer].second)
5703 return nullptr;
5704 return builder.CreateLoad(
5705 builder.getPtrTy(),
5706 info.DevicePtrInfoMap[basePointer].second);
5707 });
5708 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5709 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5710 mapData, [&](llvm::Value *basePointer) {
5711 return info.DevicePtrInfoMap[basePointer].second;
5712 });
5713
5714 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5715 moduleTranslation)))
5716 return llvm::make_error<PreviouslyReportedError>();
5717 }
5718 break;
5719 case BodyGenTy::DupNoPriv:
5720 if (info.DevicePtrInfoMap.empty()) {
5721 // For host device we still need to do the mapping for codegen,
5722 // otherwise it may try to lookup a missing value.
5723 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5724 blockArgIface.getUseDeviceAddrBlockArgs(),
5725 useDeviceAddrVars, mapData);
5726 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5727 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5728 mapData);
5729 }
5730 break;
5731 case BodyGenTy::NoPriv:
5732 // If device info is available then region has already been generated
5733 if (info.DevicePtrInfoMap.empty()) {
5734 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5735 moduleTranslation)))
5736 return llvm::make_error<PreviouslyReportedError>();
5737 }
5738 break;
5739 }
5740 return builder.saveIP();
5741 };
5742
5743 auto customMapperCB =
5744 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5745 if (!combinedInfo.Mappers[i])
5746 return nullptr;
5747 info.HasMapper = true;
5748 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5749 moduleTranslation, targetDirective);
5750 };
5751
5752 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5753 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5754 findAllocaInsertPoint(builder, moduleTranslation);
5755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5756 if (isa<omp::TargetDataOp>(op))
5757 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5758 deviceID, ifCond, info, genMapInfoCB,
5759 customMapperCB,
5760 /*MapperFunc=*/nullptr, bodyGenCB,
5761 /*DeviceAddrCB=*/nullptr);
5762 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5763 deviceID, ifCond, info, genMapInfoCB,
5764 customMapperCB, &RTLFn);
5765 }();
5766
5767 if (failed(handleError(afterIP, *op)))
5768 return failure();
5769
5770 builder.restoreIP(*afterIP);
5771 return success();
5772}
5773
5774static LogicalResult
5775convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
5776 LLVM::ModuleTranslation &moduleTranslation) {
5777 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5778 auto distributeOp = cast<omp::DistributeOp>(opInst);
5779 if (failed(checkImplementationStatus(opInst)))
5780 return failure();
5781
5782 /// Process teams op reduction in distribute if the reduction is contained in
5783 /// the distribute op.
5784 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
5785 bool doDistributeReduction =
5786 teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;
5787
5788 DenseMap<Value, llvm::Value *> reductionVariableMap;
5789 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5791 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
5792 llvm::ArrayRef<bool> isByRef;
5793
5794 if (doDistributeReduction) {
5795 isByRef = getIsByRef(teamsOp.getReductionByref());
5796 assert(isByRef.size() == teamsOp.getNumReductionVars());
5797
5798 collectReductionDecls(teamsOp, reductionDecls);
5799 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5800 findAllocaInsertPoint(builder, moduleTranslation);
5801
5802 MutableArrayRef<BlockArgument> reductionArgs =
5803 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5804 .getReductionBlockArgs();
5805
5807 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5808 reductionDecls, privateReductionVariables, reductionVariableMap,
5809 isByRef)))
5810 return failure();
5811 }
5812
5813 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5814 auto bodyGenCB = [&](InsertPointTy allocaIP,
5815 InsertPointTy codeGenIP) -> llvm::Error {
5816 // Save the alloca insertion point on ModuleTranslation stack for use in
5817 // nested regions.
5819 moduleTranslation, allocaIP);
5820
5821 // DistributeOp has only one region associated with it.
5822 builder.restoreIP(codeGenIP);
5823 PrivateVarsInfo privVarsInfo(distributeOp);
5824
5826 allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
5827 if (handleError(afterAllocas, opInst).failed())
5828 return llvm::make_error<PreviouslyReportedError>();
5829
5830 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
5831 opInst)
5832 .failed())
5833 return llvm::make_error<PreviouslyReportedError>();
5834
5835 if (failed(copyFirstPrivateVars(
5836 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
5837 privVarsInfo.llvmVars, privVarsInfo.privatizers,
5838 distributeOp.getPrivateNeedsBarrier())))
5839 return llvm::make_error<PreviouslyReportedError>();
5840
5841 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5844 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
5845 builder, moduleTranslation);
5846 if (!regionBlock)
5847 return regionBlock.takeError();
5848 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5849
5850 // Skip applying a workshare loop below when translating 'distribute
5851 // parallel do' (it's been already handled by this point while translating
5852 // the nested omp.wsloop).
5853 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5854 // TODO: Add support for clauses which are valid for DISTRIBUTE
5855 // constructs. Static schedule is the default.
5856 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5857 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5858 : omp::ClauseScheduleKind::Static;
5859 // dist_schedule clauses are ordered - otherise this should be false
5860 bool isOrdered = hasDistSchedule;
5861 std::optional<omp::ScheduleModifier> scheduleMod;
5862 bool isSimd = false;
5863 llvm::omp::WorksharingLoopType workshareLoopType =
5864 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5865 bool loopNeedsBarrier = false;
5866 llvm::Value *chunk = moduleTranslation.lookupValue(
5867 distributeOp.getDistScheduleChunkSize());
5868 llvm::CanonicalLoopInfo *loopInfo =
5869 findCurrentLoopInfo(moduleTranslation);
5870 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5871 ompBuilder->applyWorkshareLoop(
5872 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5873 convertToScheduleKind(schedule), chunk, isSimd,
5874 scheduleMod == omp::ScheduleModifier::monotonic,
5875 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5876 workshareLoopType, false, hasDistSchedule, chunk);
5877
5878 if (!wsloopIP)
5879 return wsloopIP.takeError();
5880 }
5881 if (failed(cleanupPrivateVars(builder, moduleTranslation,
5882 distributeOp.getLoc(), privVarsInfo.llvmVars,
5883 privVarsInfo.privatizers)))
5884 return llvm::make_error<PreviouslyReportedError>();
5885
5886 return llvm::Error::success();
5887 };
5888
5889 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5890 findAllocaInsertPoint(builder, moduleTranslation);
5891 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5892 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5893 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5894
5895 if (failed(handleError(afterIP, opInst)))
5896 return failure();
5897
5898 builder.restoreIP(*afterIP);
5899
5900 if (doDistributeReduction) {
5901 // Process the reductions if required.
5903 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5904 privateReductionVariables, isByRef,
5905 /*isNoWait*/ false, /*isTeamsReduction*/ true);
5906 }
5907 return success();
5908}
5909
5910/// Lowers the FlagsAttr which is applied to the module on the device
5911/// pass when offloading, this attribute contains OpenMP RTL globals that can
5912/// be passed as flags to the frontend, otherwise they are set to default
5913static LogicalResult
5914convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5915 LLVM::ModuleTranslation &moduleTranslation) {
5916 if (!cast<mlir::ModuleOp>(op))
5917 return failure();
5918
5919 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5920
5921 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5922 attribute.getOpenmpDeviceVersion());
5923
5924 if (attribute.getNoGpuLib())
5925 return success();
5926
5927 ompBuilder->createGlobalFlag(
5928 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5929 "__omp_rtl_debug_kind");
5930 ompBuilder->createGlobalFlag(
5931 attribute
5932 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5933 ,
5934 "__omp_rtl_assume_teams_oversubscription");
5935 ompBuilder->createGlobalFlag(
5936 attribute
5937 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5938 ,
5939 "__omp_rtl_assume_threads_oversubscription");
5940 ompBuilder->createGlobalFlag(
5941 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5942 "__omp_rtl_assume_no_thread_state");
5943 ompBuilder->createGlobalFlag(
5944 attribute
5945 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5946 ,
5947 "__omp_rtl_assume_no_nested_parallelism");
5948 return success();
5949}
5950
5951static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5952 omp::TargetOp targetOp,
5953 llvm::StringRef parentName = "") {
5954 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5955
5956 assert(fileLoc && "No file found from location");
5957 StringRef fileName = fileLoc.getFilename().getValue();
5958
5959 llvm::sys::fs::UniqueID id;
5960 uint64_t line = fileLoc.getLine();
5961 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5962 size_t fileHash = llvm::hash_value(fileName.str());
5963 size_t deviceId = 0xdeadf17e;
5964 targetInfo =
5965 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5966 } else {
5967 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5968 id.getFile(), line);
5969 }
5970}
5971
5972static void
5973handleDeclareTargetMapVar(MapInfoData &mapData,
5974 LLVM::ModuleTranslation &moduleTranslation,
5975 llvm::IRBuilderBase &builder, llvm::Function *func) {
5976 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5977 "function only supported for target device codegen");
5978 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5979 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5980 // In the case of declare target mapped variables, the basePointer is
5981 // the reference pointer generated by the convertDeclareTargetAttr
5982 // method. Whereas the kernelValue is the original variable, so for
5983 // the device we must replace all uses of this original global variable
5984 // (stored in kernelValue) with the reference pointer (stored in
5985 // basePointer for declare target mapped variables), as for device the
5986 // data is mapped into this reference pointer and should be loaded
5987 // from it, the original variable is discarded. On host both exist and
5988 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
5989 // function to link the two variables in the runtime and then both the
5990 // reference pointer and the pointer are assigned in the kernel argument
5991 // structure for the host.
5992 if (mapData.IsDeclareTarget[i]) {
5993 // If the original map value is a constant, then we have to make sure all
5994 // of it's uses within the current kernel/function that we are going to
5995 // rewrite are converted to instructions, as we will be altering the old
5996 // use (OriginalValue) from a constant to an instruction, which will be
5997 // illegal and ICE the compiler if the user is a constant expression of
5998 // some kind e.g. a constant GEP.
5999 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6000 convertUsersOfConstantsToInstructions(constant, func, false);
6001
6002 // The users iterator will get invalidated if we modify an element,
6003 // so we populate this vector of uses to alter each user on an
6004 // individual basis to emit its own load (rather than one load for
6005 // all).
6007 for (llvm::User *user : mapData.OriginalValue[i]->users())
6008 userVec.push_back(user);
6009
6010 for (llvm::User *user : userVec) {
6011 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
6012 if (insn->getFunction() == func) {
6013 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6014 llvm::Value *substitute = mapData.BasePointers[i];
6015 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
6016 : mapOp.getVarPtr())) {
6017 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6018 substitute = builder.CreateLoad(
6019 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6020 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6021 }
6022 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6023 }
6024 }
6025 }
6026 }
6027 }
6028}
6029
// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generate different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // All allocas must go to the function's alloca insertion block.
  builder.restoreIP(allocaIP);

  // Default to ByRef if no matching map clause is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current (alloca) insert point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // On targets with a distinct alloca address space (e.g. AMDGPU), cast the
  // stack slot's pointer back to the program address space before use.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  // Switch back to the kernel body for the accesses generated below.
  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy uses the stack slot directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //    alignment of %input    |
    //           address         |
    //                           |
    //           alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
6156
6157/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6158/// operation and populate output variables with their corresponding host value
6159/// (i.e. operand evaluated outside of the target region), based on their uses
6160/// inside of the target region.
6161///
6162/// Loop bounds and steps are only optionally populated, if output vectors are
6163/// provided.
6164static void
6165extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
6166 Value &numTeamsLower, Value &numTeamsUpper,
6167 Value &threadLimit,
6168 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
6169 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
6170 llvm::SmallVectorImpl<Value> *steps = nullptr) {
6171 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6172 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6173 blockArgIface.getHostEvalBlockArgs())) {
6174 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6175
6176 for (Operation *user : blockArg.getUsers()) {
6178 .Case([&](omp::TeamsOp teamsOp) {
6179 if (teamsOp.getNumTeamsLower() == blockArg)
6180 numTeamsLower = hostEvalVar;
6181 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6182 blockArg))
6183 numTeamsUpper = hostEvalVar;
6184 else if (!teamsOp.getThreadLimitVars().empty() &&
6185 teamsOp.getThreadLimit(0) == blockArg)
6186 threadLimit = hostEvalVar;
6187 else
6188 llvm_unreachable("unsupported host_eval use");
6189 })
6190 .Case([&](omp::ParallelOp parallelOp) {
6191 if (!parallelOp.getNumThreadsVars().empty() &&
6192 parallelOp.getNumThreads(0) == blockArg)
6193 numThreads = hostEvalVar;
6194 else
6195 llvm_unreachable("unsupported host_eval use");
6196 })
6197 .Case([&](omp::LoopNestOp loopOp) {
6198 auto processBounds =
6199 [&](OperandRange opBounds,
6200 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
6201 bool found = false;
6202 for (auto [i, lb] : llvm::enumerate(opBounds)) {
6203 if (lb == blockArg) {
6204 found = true;
6205 if (outBounds)
6206 (*outBounds)[i] = hostEvalVar;
6207 }
6208 }
6209 return found;
6210 };
6211 bool found =
6212 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6213 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6214 found;
6215 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6216 (void)found;
6217 assert(found && "unsupported host_eval use");
6218 })
6219 .DefaultUnreachable("unsupported host_eval use");
6220 }
6221 }
6222}
6223
6224/// If \p op is of the given type parameter, return it casted to that type.
6225/// Otherwise, if its immediate parent operation (or some other higher-level
6226/// parent, if \p immediateParent is false) is of that type, return that parent
6227/// casted to the given type.
6228///
6229/// If \p op is \c null or neither it or its parent(s) are of the specified
6230/// type, return a \c null operation.
6231template <typename OpTy>
6232static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6233 if (!op)
6234 return OpTy();
6235
6236 if (OpTy casted = dyn_cast<OpTy>(op))
6237 return casted;
6238
6239 if (immediateParent)
6240 return dyn_cast_if_present<OpTy>(op->getParentOp());
6241
6242 return op->getParentOfType<OpTy>();
6243}
6244
6245/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6246/// it is of an integer type, return its value.
6247static std::optional<int64_t> extractConstInteger(Value value) {
6248 if (!value)
6249 return std::nullopt;
6250
6251 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6252 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6253 return constAttr.getInt();
6254
6255 return std::nullopt;
6256}
6257
6258static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6259 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6260 uint64_t sizeInBytes = sizeInBits / 8;
6261 return sizeInBytes;
6262}
6263
6264template <typename OpTy>
6265static uint64_t getReductionDataSize(OpTy &op) {
6266 if (op.getNumReductionVars() > 0) {
6268 collectReductionDecls(op, reductions);
6269
6271 members.reserve(reductions.size());
6272 for (omp::DeclareReductionOp &red : reductions)
6273 members.push_back(red.getType());
6274 Operation *opp = op.getOperation();
6275 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6276 opp->getContext(), members, /*isPacked=*/false);
6277 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
6278 return getTypeByteSize(structType, dl);
6279 }
6280 return 0;
6281}
6282
6283/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
6284/// values as stated by the corresponding clauses, if constant.
6285///
6286/// These default values must be set before the creation of the outlined LLVM
6287/// function for the target region, so that they can be used to initialize the
6288/// corresponding global `ConfigurationEnvironmentTy` structure.
6289static void
6290initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
6291 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6292 bool isTargetDevice, bool isGPU) {
6293 // TODO: Handle constant 'if' clauses.
6294
6295 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6296 if (!isTargetDevice) {
6297 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6298 threadLimit);
6299 } else {
6300 // In the target device, values for these clauses are not passed as
6301 // host_eval, but instead evaluated prior to entry to the region. This
6302 // ensures values are mapped and available inside of the target region.
6303 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6304 numTeamsLower = teamsOp.getNumTeamsLower();
6305 // Handle num_teams upper bounds (only first value for now)
6306 if (!teamsOp.getNumTeamsUpperVars().empty())
6307 numTeamsUpper = teamsOp.getNumTeams(0);
6308 if (!teamsOp.getThreadLimitVars().empty())
6309 threadLimit = teamsOp.getThreadLimit(0);
6310 }
6311
6312 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
6313 if (!parallelOp.getNumThreadsVars().empty())
6314 numThreads = parallelOp.getNumThreads(0);
6315 }
6316 }
6317
6318 // Handle clauses impacting the number of teams.
6319
6320 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6321 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6322 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
6323 // match clang and set min and max to the same value.
6324 if (numTeamsUpper) {
6325 if (auto val = extractConstInteger(numTeamsUpper))
6326 minTeamsVal = maxTeamsVal = *val;
6327 } else {
6328 minTeamsVal = maxTeamsVal = 0;
6329 }
6330 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
6331 /*immediateParent=*/true) ||
6333 /*immediateParent=*/true)) {
6334 minTeamsVal = maxTeamsVal = 1;
6335 } else {
6336 minTeamsVal = maxTeamsVal = -1;
6337 }
6338
6339 // Handle clauses impacting the number of threads.
6340
6341 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
6342 if (!clauseValue)
6343 return;
6344
6345 if (auto val = extractConstInteger(clauseValue))
6346 result = *val;
6347
6348 // Found an applicable clause, so it's not undefined. Mark as unknown
6349 // because it's not constant.
6350 if (result < 0)
6351 result = 0;
6352 };
6353
6354 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
6355 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6356 if (!targetOp.getThreadLimitVars().empty())
6357 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6358 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6359
6360 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
6361 int32_t maxThreadsVal = -1;
6363 setMaxValueFromClause(numThreads, maxThreadsVal);
6364 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
6365 /*immediateParent=*/true))
6366 maxThreadsVal = 1;
6367
6368 // For max values, < 0 means unset, == 0 means set but unknown. Select the
6369 // minimum value between 'max_threads' and 'thread_limit' clauses that were
6370 // set.
6371 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6372 if (combinedMaxThreadsVal < 0 ||
6373 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6374 combinedMaxThreadsVal = teamsThreadLimitVal;
6375
6376 if (combinedMaxThreadsVal < 0 ||
6377 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6378 combinedMaxThreadsVal = maxThreadsVal;
6379
6380 int32_t reductionDataSize = 0;
6381 if (isGPU && capturedOp) {
6382 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
6383 reductionDataSize = getReductionDataSize(teamsOp);
6384 }
6385
6386 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
6387 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6388 assert(
6389 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6390 omp::TargetRegionFlags::spmd) &&
6391 "invalid kernel flags");
6392 attrs.ExecFlags =
6393 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6394 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6395 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6396 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6397 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6398 if (omp::bitEnumContainsAll(kernelFlags,
6399 omp::TargetRegionFlags::spmd |
6400 omp::TargetRegionFlags::no_loop) &&
6401 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6402 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6403
6404 attrs.MinTeams = minTeamsVal;
6405 attrs.MaxTeams.front() = maxTeamsVal;
6406 attrs.MinThreads = 1;
6407 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6408 attrs.ReductionDataSize = reductionDataSize;
6409 // TODO: Allow modified buffer length similar to
6410 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
6411 if (attrs.ReductionDataSize != 0)
6412 attrs.ReductionBufferLength = 1024;
6413}
6414
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  // Loop nest captured by the target region (if any); its bounds are used
  // below to compute the kernel trip count.
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // If any bound was not evaluated in the host, the trip count cannot be
      // computed; reset it and stop.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop: seed the accumulator; subsequent loops multiply into it.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Default to an undefined device ID; a 'device' clause, if present,
  // overrides it (sign-extended/truncated to i64).
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
6502
6503static LogicalResult
6504convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
6505 LLVM::ModuleTranslation &moduleTranslation) {
6506 auto targetOp = cast<omp::TargetOp>(opInst);
6507
6508 // The current debug location already has the DISubprogram for the outlined
6509 // function that will be created for the target op. We save it here so that
6510 // we can set it on the outlined function.
6511 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6512 if (failed(checkImplementationStatus(opInst)))
6513 return failure();
6514
6515 // During the handling of target op, we will generate instructions in the
6516 // parent function like call to the oulined function or branch to a new
6517 // BasicBlock. We set the debug location here to parent function so that those
6518 // get the correct debug locations. For outlined functions, the normal MLIR op
6519 // conversion will automatically pick the correct location.
6520 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6521 assert(parentBB && "No insert block is set for the builder");
6522 llvm::Function *parentLLVMFn = parentBB->getParent();
6523 assert(parentLLVMFn && "Parent Function must be valid");
6524 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6525 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6526 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6527 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6528
6529 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6530 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6531 bool isGPU = ompBuilder->Config.isGPU();
6532
6533 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
6534 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6535 auto &targetRegion = targetOp.getRegion();
6536 // Holds the private vars that have been mapped along with the block
6537 // argument that corresponds to the MapInfoOp corresponding to the private
6538 // var in question. So, for instance:
6539 //
6540 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
6541 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
6542 //
6543 // Then, %10 has been created so that the descriptor can be used by the
6544 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
6545 // %arg0} in the mappedPrivateVars map.
6546 llvm::DenseMap<Value, Value> mappedPrivateVars;
6547 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
6548 SmallVector<Value> mapVars = targetOp.getMapVars();
6549 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
6550 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
6551 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
6552 llvm::Function *llvmOutlinedFn = nullptr;
6553 TargetDirectiveEnumTy targetDirective =
6554 getTargetDirectiveEnumTyFromOp(&opInst);
6555
6556 // TODO: It can also be false if a compile-time constant `false` IF clause is
6557 // specified.
6558 bool isOffloadEntry =
6559 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6560
6561 // For some private variables, the MapsForPrivatizedVariablesPass
6562 // creates MapInfoOp instances. Go through the private variables and
6563 // the mapped variables so that during codegeneration we are able
6564 // to quickly look up the corresponding map variable, if any for each
6565 // private variable.
6566 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6567 OperandRange privateVars = targetOp.getPrivateVars();
6568 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6569 std::optional<DenseI64ArrayAttr> privateMapIndices =
6570 targetOp.getPrivateMapsAttr();
6571
6572 for (auto [privVarIdx, privVarSymPair] :
6573 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6574 auto privVar = std::get<0>(privVarSymPair);
6575 auto privSym = std::get<1>(privVarSymPair);
6576
6577 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6578 omp::PrivateClauseOp privatizer =
6579 findPrivatizer(targetOp, privatizerName);
6580
6581 if (!privatizer.needsMap())
6582 continue;
6583
6584 mlir::Value mappedValue =
6585 targetOp.getMappedValueForPrivateVar(privVarIdx);
6586 assert(mappedValue && "Expected to find mapped value for a privatized "
6587 "variable that needs mapping");
6588
6589 // The MapInfoOp defining the map var isn't really needed later.
6590 // So, we don't store it in any datastructure. Instead, we just
6591 // do some sanity checks on it right now.
6592 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
6593 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
6594
6595 // Check #1: Check that the type of the private variable matches
6596 // the type of the variable being mapped.
6597 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6598 assert(
6599 varType == privVar.getType() &&
6600 "Type of private var doesn't match the type of the mapped value");
6601
6602 // Ok, only 1 sanity check for now.
6603 // Record the block argument corresponding to this mapvar.
6604 mappedPrivateVars.insert(
6605 {privVar,
6606 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6607 (*privateMapIndices)[privVarIdx])});
6608 }
6609 }
6610
6611 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6612 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6613 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6614 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6615 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6616 // Forward target-cpu and target-features function attributes from the
6617 // original function to the new outlined function.
6618 llvm::Function *llvmParentFn =
6619 moduleTranslation.lookupFunction(parentFn.getName());
6620 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6621 assert(llvmParentFn && llvmOutlinedFn &&
6622 "Both parent and outlined functions must exist at this point");
6623
6624 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6625 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6626
6627 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
6628 attr.isStringAttribute())
6629 llvmOutlinedFn->addFnAttr(attr);
6630
6631 if (auto attr = llvmParentFn->getFnAttribute("target-features");
6632 attr.isStringAttribute())
6633 llvmOutlinedFn->addFnAttr(attr);
6634
6635 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6636 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6637 llvm::Value *mapOpValue =
6638 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6639 moduleTranslation.mapValue(arg, mapOpValue);
6640 }
6641 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6642 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6643 llvm::Value *mapOpValue =
6644 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6645 moduleTranslation.mapValue(arg, mapOpValue);
6646 }
6647
6648 // Do privatization after moduleTranslation has already recorded
6649 // mapped values.
6650 PrivateVarsInfo privateVarsInfo(targetOp);
6651
6653 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
6654 allocaIP, &mappedPrivateVars);
6655
6656 if (failed(handleError(afterAllocas, *targetOp)))
6657 return llvm::make_error<PreviouslyReportedError>();
6658
6659 builder.restoreIP(codeGenIP);
6660 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
6661 &mappedPrivateVars),
6662 *targetOp)
6663 .failed())
6664 return llvm::make_error<PreviouslyReportedError>();
6665
6666 if (failed(copyFirstPrivateVars(
6667 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
6668 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
6669 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6670 return llvm::make_error<PreviouslyReportedError>();
6671
6672 SmallVector<Region *> privateCleanupRegions;
6673 llvm::transform(privateVarsInfo.privatizers,
6674 std::back_inserter(privateCleanupRegions),
6675 [](omp::PrivateClauseOp privatizer) {
6676 return &privatizer.getDeallocRegion();
6677 });
6678
6680 targetRegion, "omp.target", builder, moduleTranslation);
6681
6682 if (!exitBlock)
6683 return exitBlock.takeError();
6684
6685 builder.SetInsertPoint(*exitBlock);
6686 if (!privateCleanupRegions.empty()) {
6687 if (failed(inlineOmpRegionCleanup(
6688 privateCleanupRegions, privateVarsInfo.llvmVars,
6689 moduleTranslation, builder, "omp.targetop.private.cleanup",
6690 /*shouldLoadCleanupRegionArg=*/false))) {
6691 return llvm::createStringError(
6692 "failed to inline `dealloc` region of `omp.private` "
6693 "op in the target region");
6694 }
6695 return builder.saveIP();
6696 }
6697
6698 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6699 };
6700
6701 StringRef parentName = parentFn.getName();
6702
6703 llvm::TargetRegionEntryInfo entryInfo;
6704
6705 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
6706
6707 MapInfoData mapData;
6708 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6709 builder, /*useDevPtrOperands=*/{},
6710 /*useDevAddrOperands=*/{}, hdaVars);
6711
6712 MapInfosTy combinedInfos;
6713 auto genMapInfoCB =
6714 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6715 builder.restoreIP(codeGenIP);
6716 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6717 targetDirective);
6718 return combinedInfos;
6719 };
6720
6721 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6722 llvm::Value *&retVal, InsertPointTy allocaIP,
6723 InsertPointTy codeGenIP)
6724 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6725 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6726 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6727 // We just return the unaltered argument for the host function
6728 // for now, some alterations may be required in the future to
6729 // keep host fallback functions working identically to the device
6730 // version (e.g. pass ByCopy values should be treated as such on
6731 // host and device, currently not always the case)
6732 if (!isTargetDevice) {
6733 retVal = cast<llvm::Value>(&arg);
6734 return codeGenIP;
6735 }
6736
6737 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
6738 *ompBuilder, moduleTranslation,
6739 allocaIP, codeGenIP);
6740 };
6741
6742 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6743 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6744 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6745 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
6746 isTargetDevice, isGPU);
6747
6748 // Collect host-evaluated values needed to properly launch the kernel from the
6749 // host.
6750 if (!isTargetDevice)
6751 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
6752 targetCapturedOp, runtimeAttrs);
6753
6754 // Pass host-evaluated values as parameters to the kernel / host fallback,
6755 // except if they are constants. In any case, map the MLIR block argument to
6756 // the corresponding LLVM values.
6758 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
6759 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
6760 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6761 llvm::Value *value = moduleTranslation.lookupValue(var);
6762 moduleTranslation.mapValue(arg, value);
6763
6764 if (!llvm::isa<llvm::Constant>(value))
6765 kernelInput.push_back(value);
6766 }
6767
6768 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6769 // declare target arguments are not passed to kernels as arguments
6770 // TODO: We currently do not handle cases where a member is explicitly
6771 // passed in as an argument, this will likley need to be handled in
6772 // the near future, rather than using IsAMember, it may be better to
6773 // test if the relevant BlockArg is used within the target region and
6774 // then use that as a basis for exclusion in the kernel inputs.
6775 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6776 kernelInput.push_back(mapData.OriginalValue[i]);
6777 }
6778
6780 buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
6781 moduleTranslation, dds);
6782
6783 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6784 findAllocaInsertPoint(builder, moduleTranslation);
6785 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6786
6787 llvm::OpenMPIRBuilder::TargetDataInfo info(
6788 /*RequiresDevicePointerInfo=*/false,
6789 /*SeparateBeginEndCalls=*/true);
6790
6791 auto customMapperCB =
6792 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6793 if (!combinedInfos.Mappers[i])
6794 return nullptr;
6795 info.HasMapper = true;
6796 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
6797 moduleTranslation, targetDirective);
6798 };
6799
6800 llvm::Value *ifCond = nullptr;
6801 if (Value targetIfCond = targetOp.getIfExpr())
6802 ifCond = moduleTranslation.lookupValue(targetIfCond);
6803
6804 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6805 moduleTranslation.getOpenMPBuilder()->createTarget(
6806 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6807 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6808 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6809
6810 if (failed(handleError(afterIP, opInst)))
6811 return failure();
6812
6813 builder.restoreIP(*afterIP);
6814
6815 // Remap access operations to declare target reference pointers for the
6816 // device, essentially generating extra loadop's as necessary
6817 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
6818 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
6819 llvmOutlinedFn);
6820
6821 return success();
6822}
6823
/// Handle the `omp.declare_target` attribute attached to a function or a
/// global during translation to LLVM IR.
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      // Host compilation: nothing to amend.
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      if (declareType == omp::DeclareTargetDeviceType::host) {
        // NOTE(review): lookupFunction may return null if the function was
        // never translated, in which case this would dereference a null
        // pointer — confirm callers guarantee the LLVM function exists here.
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      // File/line location, used below only for offload entry bookkeeping.
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      // Target triple taken from the enclosing module's attribute, if set.
      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Provides {filename, line} identifying the offload entry; falls back
      // to an empty name and line 0 when no FileLineColLoc is available.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      auto vfs = llvm::vfs::getRealFileSystem();

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
          mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      // On the device, non-'to' captures (or unified shared memory mode) also
      // need the declare-target variable's address materialized.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
            mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
            gVal->getType(), /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
6915
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
6940
6941LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6942 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6943 NamedAttribute attribute,
6944 LLVM::ModuleTranslation &moduleTranslation) const {
6945 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6946 attribute.getName())
6947 .Case("omp.is_target_device",
6948 [&](Attribute attr) {
6949 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6950 llvm::OpenMPIRBuilderConfig &config =
6951 moduleTranslation.getOpenMPBuilder()->Config;
6952 config.setIsTargetDevice(deviceAttr.getValue());
6953 return success();
6954 }
6955 return failure();
6956 })
6957 .Case("omp.is_gpu",
6958 [&](Attribute attr) {
6959 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6960 llvm::OpenMPIRBuilderConfig &config =
6961 moduleTranslation.getOpenMPBuilder()->Config;
6962 config.setIsGPU(gpuAttr.getValue());
6963 return success();
6964 }
6965 return failure();
6966 })
6967 .Case("omp.host_ir_filepath",
6968 [&](Attribute attr) {
6969 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6970 llvm::OpenMPIRBuilder *ompBuilder =
6971 moduleTranslation.getOpenMPBuilder();
6972 auto VFS = llvm::vfs::getRealFileSystem();
6973 ompBuilder->loadOffloadInfoMetadata(*VFS,
6974 filepathAttr.getValue());
6975 return success();
6976 }
6977 return failure();
6978 })
6979 .Case("omp.flags",
6980 [&](Attribute attr) {
6981 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6982 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6983 return failure();
6984 })
6985 .Case("omp.version",
6986 [&](Attribute attr) {
6987 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6988 llvm::OpenMPIRBuilder *ompBuilder =
6989 moduleTranslation.getOpenMPBuilder();
6990 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6991 versionAttr.getVersion());
6992 return success();
6993 }
6994 return failure();
6995 })
6996 .Case("omp.declare_target",
6997 [&](Attribute attr) {
6998 if (auto declareTargetAttr =
6999 dyn_cast<omp::DeclareTargetAttr>(attr))
7000 return convertDeclareTargetAttr(op, declareTargetAttr,
7001 moduleTranslation);
7002 return failure();
7003 })
7004 .Case("omp.requires",
7005 [&](Attribute attr) {
7006 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7007 using Requires = omp::ClauseRequires;
7008 Requires flags = requiresAttr.getValue();
7009 llvm::OpenMPIRBuilderConfig &config =
7010 moduleTranslation.getOpenMPBuilder()->Config;
7011 config.setHasRequiresReverseOffload(
7012 bitEnumContainsAll(flags, Requires::reverse_offload));
7013 config.setHasRequiresUnifiedAddress(
7014 bitEnumContainsAll(flags, Requires::unified_address));
7015 config.setHasRequiresUnifiedSharedMemory(
7016 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7017 config.setHasRequiresDynamicAllocators(
7018 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7019 return success();
7020 }
7021 return failure();
7022 })
7023 .Case("omp.target_triples",
7024 [&](Attribute attr) {
7025 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7026 llvm::OpenMPIRBuilderConfig &config =
7027 moduleTranslation.getOpenMPBuilder()->Config;
7028 config.TargetTriples.clear();
7029 config.TargetTriples.reserve(triplesAttr.size());
7030 for (Attribute tripleAttr : triplesAttr) {
7031 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7032 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7033 else
7034 return failure();
7035 }
7036 return success();
7037 }
7038 return failure();
7039 })
7040 .Default([](Attribute) {
7041 // Fall through for omp attributes that do not require lowering.
7042 return success();
7043 })(attribute.getValue());
7044
7045 return failure();
7046}
7047
7048// Returns true if the operation is not inside a TargetOp, it is part of a
7049// function and that function is not declare target.
7050static bool isHostDeviceOp(Operation *op) {
7051 // Assumes no reverse offloading
7052 if (op->getParentOfType<omp::TargetOp>())
7053 return false;
7054
7055 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
7056 if (auto declareTargetIface =
7057 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7058 parentFn.getOperation()))
7059 if (declareTargetIface.isDeclareTarget() &&
7060 declareTargetIface.getDeclareTargetDeviceType() !=
7061 mlir::omp::DeclareTargetDeviceType::host)
7062 return false;
7063
7064 return true;
7065 }
7066
7067 return false;
7068}
7069
7070static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7071 llvm::Module *llvmModule) {
7072 llvm::Type *i64Ty = builder.getInt64Ty();
7073 llvm::Type *i32Ty = builder.getInt32Ty();
7074 llvm::Type *returnType = builder.getPtrTy(0);
7075 llvm::FunctionType *fnType =
7076 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7077 llvm::Function *func = cast<llvm::Function>(
7078 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7079 return func;
7080}
7081
7082static LogicalResult
7083convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7084 LLVM::ModuleTranslation &moduleTranslation) {
7085 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7086 if (!allocMemOp)
7087 return failure();
7088
7089 // Get "omp_target_alloc" function
7090 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7091 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7092 // Get the corresponding device value in llvm
7093 mlir::Value deviceNum = allocMemOp.getDevice();
7094 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7095 // Get the allocation size.
7096 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7097 mlir::Type heapTy = allocMemOp.getAllocatedType();
7098 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
7099 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7100 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7101 for (auto typeParam : allocMemOp.getTypeparams())
7102 allocSize =
7103 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
7104 // Create call to "omp_target_alloc" with the args as translated llvm values.
7105 llvm::CallInst *call =
7106 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7107 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7108
7109 // Map the result
7110 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7111 return success();
7112}
7113
7114static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
7115 llvm::Module *llvmModule) {
7116 llvm::Type *ptrTy = builder.getPtrTy(0);
7117 llvm::Type *i32Ty = builder.getInt32Ty();
7118 llvm::Type *voidTy = builder.getVoidTy();
7119 llvm::FunctionType *fnType =
7120 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
7121 llvm::Function *func = dyn_cast<llvm::Function>(
7122 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
7123 return func;
7124}
7125
7126static LogicalResult
7127convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7128 LLVM::ModuleTranslation &moduleTranslation) {
7129 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7130 if (!freeMemOp)
7131 return failure();
7132
7133 // Get "omp_target_free" function
7134 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7135 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7136 // Get the corresponding device value in llvm
7137 mlir::Value deviceNum = freeMemOp.getDevice();
7138 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7139 // Get the corresponding heapref value in llvm
7140 mlir::Value heapref = freeMemOp.getHeapref();
7141 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7142 // Convert heapref int to ptr and call "omp_target_free"
7143 llvm::Value *intToPtr =
7144 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7145 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7146 return success();
7147}
7148
7149/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
7150/// OpenMP runtime calls).
7151LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7152 Operation *op, llvm::IRBuilderBase &builder,
7153 LLVM::ModuleTranslation &moduleTranslation) const {
7154 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7155
7156 if (ompBuilder->Config.isTargetDevice() &&
7157 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7158 op) &&
7159 isHostDeviceOp(op))
7160 return op->emitOpError() << "unsupported host op found in device";
7161
7162 // For each loop, introduce one stack frame to hold loop information. Ensure
7163 // this is only done for the outermost loop wrapper to prevent introducing
7164 // multiple stack frames for a single loop. Initially set to null, the loop
7165 // information structure is initialized during translation of the nested
7166 // omp.loop_nest operation, making it available to translation of all loop
7167 // wrappers after their body has been successfully translated.
7168 bool isOutermostLoopWrapper =
7169 isa_and_present<omp::LoopWrapperInterface>(op) &&
7170 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
7171
7172 if (isOutermostLoopWrapper)
7173 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7174
7175 auto result =
7176 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7177 .Case([&](omp::BarrierOp op) -> LogicalResult {
7179 return failure();
7180
7181 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7182 ompBuilder->createBarrier(builder.saveIP(),
7183 llvm::omp::OMPD_barrier);
7184 LogicalResult res = handleError(afterIP, *op);
7185 if (res.succeeded()) {
7186 // If the barrier generated a cancellation check, the insertion
7187 // point might now need to be changed to a new continuation block
7188 builder.restoreIP(*afterIP);
7189 }
7190 return res;
7191 })
7192 .Case([&](omp::TaskyieldOp op) {
7194 return failure();
7195
7196 ompBuilder->createTaskyield(builder.saveIP());
7197 return success();
7198 })
7199 .Case([&](omp::FlushOp op) {
7201 return failure();
7202
7203 // No support in Openmp runtime function (__kmpc_flush) to accept
7204 // the argument list.
7205 // OpenMP standard states the following:
7206 // "An implementation may implement a flush with a list by ignoring
7207 // the list, and treating it the same as a flush without a list."
7208 //
7209 // The argument list is discarded so that, flush with a list is
7210 // treated same as a flush without a list.
7211 ompBuilder->createFlush(builder.saveIP());
7212 return success();
7213 })
7214 .Case([&](omp::ParallelOp op) {
7215 return convertOmpParallel(op, builder, moduleTranslation);
7216 })
7217 .Case([&](omp::MaskedOp) {
7218 return convertOmpMasked(*op, builder, moduleTranslation);
7219 })
7220 .Case([&](omp::MasterOp) {
7221 return convertOmpMaster(*op, builder, moduleTranslation);
7222 })
7223 .Case([&](omp::CriticalOp) {
7224 return convertOmpCritical(*op, builder, moduleTranslation);
7225 })
7226 .Case([&](omp::OrderedRegionOp) {
7227 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
7228 })
7229 .Case([&](omp::OrderedOp) {
7230 return convertOmpOrdered(*op, builder, moduleTranslation);
7231 })
7232 .Case([&](omp::WsloopOp) {
7233 return convertOmpWsloop(*op, builder, moduleTranslation);
7234 })
7235 .Case([&](omp::SimdOp) {
7236 return convertOmpSimd(*op, builder, moduleTranslation);
7237 })
7238 .Case([&](omp::AtomicReadOp) {
7239 return convertOmpAtomicRead(*op, builder, moduleTranslation);
7240 })
7241 .Case([&](omp::AtomicWriteOp) {
7242 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
7243 })
7244 .Case([&](omp::AtomicUpdateOp op) {
7245 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
7246 })
7247 .Case([&](omp::AtomicCaptureOp op) {
7248 return convertOmpAtomicCapture(op, builder, moduleTranslation);
7249 })
7250 .Case([&](omp::CancelOp op) {
7251 return convertOmpCancel(op, builder, moduleTranslation);
7252 })
7253 .Case([&](omp::CancellationPointOp op) {
7254 return convertOmpCancellationPoint(op, builder, moduleTranslation);
7255 })
7256 .Case([&](omp::SectionsOp) {
7257 return convertOmpSections(*op, builder, moduleTranslation);
7258 })
7259 .Case([&](omp::SingleOp op) {
7260 return convertOmpSingle(op, builder, moduleTranslation);
7261 })
7262 .Case([&](omp::TeamsOp op) {
7263 return convertOmpTeams(op, builder, moduleTranslation);
7264 })
7265 .Case([&](omp::TaskOp op) {
7266 return convertOmpTaskOp(op, builder, moduleTranslation);
7267 })
7268 .Case([&](omp::TaskloopOp op) {
7269 return convertOmpTaskloopOp(*op, builder, moduleTranslation);
7270 })
7271 .Case([&](omp::TaskgroupOp op) {
7272 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
7273 })
7274 .Case([&](omp::TaskwaitOp op) {
7275 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
7276 })
7277 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7278 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7279 omp::CriticalDeclareOp>([](auto op) {
7280 // `yield` and `terminator` can be just omitted. The block structure
7281 // was created in the region that handles their parent operation.
7282 // `declare_reduction` will be used by reductions and is not
7283 // converted directly, skip it.
7284 // `declare_mapper` and `declare_mapper.info` are handled whenever
7285 // they are referred to through a `map` clause.
7286 // `critical.declare` is only used to declare names of critical
7287 // sections which will be used by `critical` ops and hence can be
7288 // ignored for lowering. The OpenMP IRBuilder will create unique
7289 // name for critical section names.
7290 return success();
7291 })
7292 .Case([&](omp::ThreadprivateOp) {
7293 return convertOmpThreadprivate(*op, builder, moduleTranslation);
7294 })
7295 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7296 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
7297 return convertOmpTargetData(op, builder, moduleTranslation);
7298 })
7299 .Case([&](omp::TargetOp) {
7300 return convertOmpTarget(*op, builder, moduleTranslation);
7301 })
7302 .Case([&](omp::DistributeOp) {
7303 return convertOmpDistribute(*op, builder, moduleTranslation);
7304 })
7305 .Case([&](omp::LoopNestOp) {
7306 return convertOmpLoopNest(*op, builder, moduleTranslation);
7307 })
7308 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7309 [&](auto op) {
7310 // No-op, should be handled by relevant owning operations e.g.
7311 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
7312 // etc. and then discarded
7313 return success();
7314 })
7315 .Case([&](omp::NewCliOp op) {
7316 // Meta-operation: Doesn't do anything by itself, but used to
7317 // identify a loop.
7318 return success();
7319 })
7320 .Case([&](omp::CanonicalLoopOp op) {
7321 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
7322 })
7323 .Case([&](omp::UnrollHeuristicOp op) {
7324 // FIXME: Handling omp.unroll_heuristic as an executable requires
7325 // that the generator (e.g. omp.canonical_loop) has been seen first.
7326 // For construct that require all codegen to occur inside a callback
7327 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
7328 // contained region including their transformations must occur at
7329 // the omp.canonical_loop.
7330 return applyUnrollHeuristic(op, builder, moduleTranslation);
7331 })
7332 .Case([&](omp::TileOp op) {
7333 return applyTile(op, builder, moduleTranslation);
7334 })
7335 .Case([&](omp::FuseOp op) {
7336 return applyFuse(op, builder, moduleTranslation);
7337 })
7338 .Case([&](omp::TargetAllocMemOp) {
7339 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
7340 })
7341 .Case([&](omp::TargetFreeMemOp) {
7342 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
7343 })
7344 .Default([&](Operation *inst) {
7345 return inst->emitError()
7346 << "not yet implemented: " << inst->getName();
7347 });
7348
7349 if (isOutermostLoopWrapper)
7350 moduleTranslation.stackPop();
7351
7352 return result;
7353}
7354
7356 registry.insert<omp::OpenMPDialect>();
7357 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7358 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7359 });
7360}
7361
7363 DialectRegistry registry;
7365 context.appendDialectRegistry(registry);
7366}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values for a list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
Value getOperand(unsigned idx)
Definition Operation.h:379
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
unsigned getNumOperands()
Definition Operation.h:375
OperandRange operand_range
Definition Operation.h:400
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1306
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars