MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
  // NOTE(review): a line appears to be missing from this excerpt here —
  // presumably MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
  // OpenMPAllocaStackFrame), mirroring OpenMPLoopInfoStackFrame below.
  // Confirm against the upstream file.

  /// Records the insertion point at which allocas for the enclosed region
  /// should be created (consumed by findAllocaInsertPoint's stack walk).
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
81
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Remains null until filled in during loop-nest translation; read back via a
  // stack walk over frames of this type.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
91
92/// Custom error class to signal translation errors that don't need reporting,
93/// since encountering them will have already triggered relevant error messages.
94///
95/// Its purpose is to serve as the glue between MLIR failures represented as
96/// \see LogicalResult instances and \see llvm::Error instances used to
97/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
98/// error of the first type is raised, a message is emitted directly (the \see
99/// LogicalResult itself does not hold any information). If we need to forward
100/// this error condition as an \see llvm::Error while avoiding triggering some
101/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
102/// class to just signal this situation has happened.
103///
104/// For example, this class should be used to trigger errors from within
105/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
106/// translation of their own regions. This unclutters the error log from
107/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  /// Intentionally a no-op: the diagnostic was already emitted when the MLIR
  /// failure was first raised, so logging again would duplicate it.
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  /// This error is purely a signal and has no std::error_code equivalent;
  /// reaching this conversion indicates a programming error.
  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

// Out-of-line definition required for the llvm::ErrorInfo RTTI mechanism.
char PreviouslyReportedError::ID = 0;
125
/// Custom class for processing the linear clause for omp.wsloop and omp.simd.
/// Linear clause translation requires setup, initialization, update, and
/// finalization at varying basic blocks in the IR. This class helps maintain
/// internal state to allow consistent translation in each of these stages.
class LinearClauseProcessor {

private:
  // The vectors below are indexed in parallel: entry `i` of each describes the
  // i-th linear variable of the construct.
  /// Allocas capturing the value of each linear variable before the loop.
  SmallVector<llvm::Value *> linearPreconditionVars;
  /// Allocas holding the per-iteration (loop-body) value of each variable.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  /// Original storage of each linear variable outside the loop.
  SmallVector<llvm::Value *> linearOrigVal;
  /// Step value of each linear variable.
  SmallVector<llvm::Value *> linearSteps;
  /// Translated LLVM type of each linear variable.
  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks created by splitLinearFiniBB; the stores back to the original
  // variables are emitted in linearLastIterExitBB and gated by a conditional
  // branch created in finalizeLinearVar.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register type for the linear variables
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space for linear variables. registerType must already have been
  // called for `idx` so that linearVarTypes[idx] is valid.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Initialize linear step by looking up the translated LLVM value.
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: snapshot the incoming
  // value of every linear variable into its precondition alloca, right before
  // the loop pre-header's terminator.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR for updating linear variables: temp = start + iv * step, emitted
  // in the loop body before its terminator.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "
                         "type");

      if (linearVarType->isIntegerTy()) {
        // Integer path: normalize all arithmetic to linearVarType
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        // Float path: perform multiply in integer, then convert to float
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else {
        llvm_unreachable(
            "Linear variable must be of integer or floating-point type");
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same. Fall-through order after the splits:
  // linearFinalizationBB -> linearLastIterExitBB -> linearExitBB (which holds
  // the original loop-exit terminator).
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: store the computed values back to the original
  // variables only on the last logical iteration, then emit a barrier.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    // The old unconditional terminator is now dead; drop it.
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Emit stores for linear variables. Useful in case of SIMD
  // construct.
  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  // Rewrite all uses of the original variable in `BBName`
  // with the linear variable in-place. Users are snapshotted into a vector
  // first because replaceUsesOfWith mutates the use list being iterated.
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        // Matches by substring of the containing basic block's name.
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
297
298} // namespace
299
300/// Starting the lookup at the operation `from`, returns the PrivateClauseOp
301/// with name `symbolName`.
static omp::PrivateClauseOp findPrivatizer(Operation *from,
                                           SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      // NOTE(review): the lookup-call line is missing from this excerpt — by
      // all appearances a SymbolTable::lookupNearestSymbolFrom<
      // omp::PrivateClauseOp>(from, ...) call; confirm against upstream.
      symbolName);
  // Callers rely on the symbol always resolving; a miss is a verifier bug.
  assert(privatizer && "privatizer not found in the symbol table");
  return privatizer;
}
310
311/// Check whether translation to LLVM IR for the given operation is currently
312/// supported. If not, descriptive diagnostics will be emitted to let users know
313/// this is a not-yet-implemented feature.
314///
315/// \returns success if no unimplemented features are needed to translate the
316/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Shared diagnostic emitter: every unsupported clause produces the same
  // "not yet implemented" error, parameterized by the clause name.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // Each checker below inspects one clause on the given operation and sets
  // `result` to failure when the clause is present but unsupported.
  auto checkAffinity = [&todo](auto op, LogicalResult &result) {
    if (!op.getAffinityVars().empty())
      result = todo("affinity");
  };
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  // hint is only discarded with a warning, never an error.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    // Plain reduction is only rejected on teams/taskloop; the modifier check
    // applies to every operation routed through this checker.
    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumTeamsMultiDim())
      result = todo("num_teams with multi-dimensional values");
  };
  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumThreadsMultiDim())
      result = todo("num_threads with multi-dimensional values");
  };

  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
    if (op.hasThreadLimitMultiDim())
      result = todo("thread_limit with multi-dimensional values");
  };

  LogicalResult result = success();
  // NOTE(review): the dispatch header line is missing from this excerpt — by
  // all appearances an `llvm::TypeSwitch<Operation &>(op)` expression whose
  // chained .Case(...) calls below select the clause checks per op type.
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkNumTeams(op, result);
        checkThreadLimit(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAffinity(op, result);
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
        checkNumThreads(op, result);
      })
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkInReduction(op, result);
        checkThreadLimit(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
461
462static LogicalResult handleError(llvm::Error error, Operation &op) {
463 LogicalResult result = success();
464 if (error) {
465 llvm::handleAllErrors(
466 std::move(error),
467 [&](const PreviouslyReportedError &) { result = failure(); },
468 [&](const llvm::ErrorInfoBase &err) {
469 result = op.emitError(err.message());
470 });
471 }
472 return result;
473}
474
475template <typename T>
476static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
477 if (!result)
478 return handleError(result.takeError(), op);
479
480 return success();
481}
482
483/// Find the insertion point for allocas given the current insertion point for
484/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it. The walk interrupts at the innermost frame.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  // In cases with multiple levels of outlining, the tree walk might find an
  // alloca insertion point that is inside the original function while the
  // builder insertion point is inside the outlined function. We need to make
  // sure that we do not use it in those cases.
  if (walkResult.wasInterrupted() &&
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocas.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
529
/// Find the loop information structure for the loop nest being translated. It
/// will return a `null` value unless called from the translation function for
/// a loop wrapper operation after successfully translating its body.
static llvm::CanonicalLoopInfo *
// NOTE(review): the line carrying the function name and parameter list is
// missing from this excerpt (presumably
// `findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {`,
// given the use of `moduleTranslation` below); confirm against upstream.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  // Interrupt at the first (innermost) frame found on the stack.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
        return WalkResult::interrupt();
      });
  return loopInfo;
}
543
544/// Converts the given region that appears within an OpenMP dialect operation to
545/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
546/// region, and a branch from any block with a successor-less OpenMP terminator
547/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
548/// of the continuation block if provided.
// NOTE(review): the line introducing this function (its return type and name
// — `static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(` by all
// appearances from the call sites elsewhere in this file) is missing from
// this excerpt; confirm against upstream.
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  // Loop wrappers carry no terminators and receive special handling below.
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  llvm::LLVMContext &llvmContext = builder.getContext();
  // Pre-create an empty LLVM block for every MLIR block so that branches can
  // be remapped before block bodies are translated.
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // First yield found defines the expected operand types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Subsequent yields must match the first one exactly.
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  // NOTE(review): the line defining `blocks` (a dominance-sorted collection of
  // the region's blocks) is missing from this excerpt; confirm upstream.
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
671
672/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
673static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
674 switch (kind) {
675 case omp::ClauseProcBindKind::Close:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
677 case omp::ClauseProcBindKind::Master:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
679 case omp::ClauseProcBindKind::Primary:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
681 case omp::ClauseProcBindKind::Spread:
682 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
683 }
684 llvm_unreachable("Unknown ClauseProcBindKind kind");
685}
686
687/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
688static LogicalResult
689convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
690 LLVM::ModuleTranslation &moduleTranslation) {
691 auto maskedOp = cast<omp::MaskedOp>(opInst);
692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
693
694 if (failed(checkImplementationStatus(opInst)))
695 return failure();
696
697 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
698 // MaskedOp has only one region associated with it.
699 auto &region = maskedOp.getRegion();
700 builder.restoreIP(codeGenIP);
701 return convertOmpOpRegions(region, "omp.masked.region", builder,
702 moduleTranslation)
703 .takeError();
704 };
705
706 // TODO: Perform finalization actions for variables. This has to be
707 // called for variables which have destructors/finalizers.
708 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
709
710 llvm::Value *filterVal = nullptr;
711 if (auto filterVar = maskedOp.getFilteredThreadId()) {
712 filterVal = moduleTranslation.lookupValue(filterVar);
713 } else {
714 llvm::LLVMContext &llvmContext = builder.getContext();
715 filterVal =
716 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
717 }
718 assert(filterVal != nullptr);
719 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
720 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
721 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
722 finiCB, filterVal);
723
724 if (failed(handleError(afterIP, opInst)))
725 return failure();
726
727 builder.restoreIP(*afterIP);
728 return success();
729}
730
731/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
732static LogicalResult
733convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
734 LLVM::ModuleTranslation &moduleTranslation) {
735 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
736 auto masterOp = cast<omp::MasterOp>(opInst);
737
738 if (failed(checkImplementationStatus(opInst)))
739 return failure();
740
741 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
742 // MasterOp has only one region associated with it.
743 auto &region = masterOp.getRegion();
744 builder.restoreIP(codeGenIP);
745 return convertOmpOpRegions(region, "omp.master.region", builder,
746 moduleTranslation)
747 .takeError();
748 };
749
750 // TODO: Perform finalization actions for variables. This has to be
751 // called for variables which have destructors/finalizers.
752 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
753
754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
756 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
757 finiCB);
758
759 if (failed(handleError(afterIP, opInst)))
760 return failure();
761
762 builder.restoreIP(*afterIP);
763 return success();
764}
765
766/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too.
  if (criticalOp.getNameAttr()) {
    // The verifiers in the OpenMP dialect guarantee that all the pointers are
    // non-null.
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        // NOTE(review): the symbol-table lookup call line is missing from this
        // excerpt — by all appearances a SymbolTable::lookupNearestSymbolFrom<
        // omp::CriticalDeclareOp>(criticalOp, ...) call; confirm upstream.
        symbolRef);
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
815
816/// A util to collect info needed to convert delayed privatizers from MLIR to
817/// LLVM.
// NOTE(review): several lines of this struct were lost in extraction (the
// `struct PrivateVarsInfo {` line, the constructor's opening line, and the
// member declarations around line 831-834). The members referenced below are
// blockArgs, mlirVars, llvmVars, and privatizers — verify against upstream.
819 template <typename OP>
// Constructor: caches the op's private-entry block arguments, reserves space
// for the per-variable vectors, and collects privatizer decls and MLIR values.
821 : blockArgs(
822 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
823 mlirVars.reserve(blockArgs.size());
824 llvmVars.reserve(blockArgs.size());
825 collectPrivatizationDecls<OP>(op);
826
827 for (mlir::Value privateVar : op.getPrivateVars())
828 mlirVars.push_back(privateVar);
829 }
830
835
836private:
837 /// Populates `privatizations` with privatization declarations used for the
838 /// given op.
839 template <class OP>
840 void collectPrivatizationDecls(OP op) {
841 std::optional<ArrayAttr> attr = op.getPrivateSyms();
// An absent private_syms attribute means there is nothing to collect.
842 if (!attr)
843 return;
844
845 privatizers.reserve(privatizers.size() + attr->size());
846 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
847 privatizers.push_back(findPrivatizer(op, symbolRef));
848 }
849 }
850};
851
852/// Populates `reductions` with reduction declarations used in the given op.
853template <typename T>
854static void
// NOTE(review): the signature continuation lines (855-856, declaring `op` and
// the output `reductions` vector) are missing from this extracted listing —
// verify against upstream.
857 std::optional<ArrayAttr> attr = op.getReductionSyms();
// No reduction_syms attribute: nothing to collect.
858 if (!attr)
859 return;
860
861 reductions.reserve(reductions.size() + op.getNumReductionVars());
862 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
863 reductions.push_back(
// NOTE(review): line 864 (presumably the symbol-table lookup of the
// DeclareReductionOp) is missing from this extracted listing.
865 op, symbolRef));
866 }
867}
868
869/// Translates the blocks contained in the given region and appends them to at
870/// the current insertion point of `builder`. The operations of the entry block
871/// are appended to the current insertion block. If set, `continuationBlockArgs`
872/// is populated with translated values that correspond to the values
873/// omp.yield'ed from the region.
874static LogicalResult inlineConvertOmpRegions(
875 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
876 LLVM::ModuleTranslation &moduleTranslation,
877 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
878 if (region.empty())
879 return success();
880
881 // Special case for single-block regions that don't create additional blocks:
882 // insert operations without creating additional blocks.
883 if (region.hasOneBlock()) {
// If the insertion block already ends with a terminator, detach it so the
// region's operations can be appended; it is re-attached below.
884 llvm::Instruction *potentialTerminator =
885 builder.GetInsertBlock()->empty() ? nullptr
886 : &builder.GetInsertBlock()->back();
887
888 if (potentialTerminator && potentialTerminator->isTerminator())
889 potentialTerminator->removeFromParent();
890 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
891
892 if (failed(moduleTranslation.convertBlock(
893 region.front(), /*ignoreArguments=*/true, builder)))
894 return failure();
895
896 // The continuation arguments are simply the translated terminator operands.
897 if (continuationBlockArgs)
898 llvm::append_range(
899 *continuationBlockArgs,
900 moduleTranslation.lookupValues(region.front().back().getOperands()));
901
902 // Drop the mapping that is no longer necessary so that the same region can
903 // be processed multiple times.
904 moduleTranslation.forgetMapping(region);
905
// Re-attach the previously detached terminator at the correct position.
906 if (potentialTerminator && potentialTerminator->isTerminator()) {
907 llvm::BasicBlock *block = builder.GetInsertBlock();
908 if (block->empty()) {
909 // this can happen for really simple reduction init regions e.g.
910 // %0 = llvm.mlir.constant(0 : i32) : i32
911 // omp.yield(%0 : i32)
912 // because the llvm.mlir.constant (MLIR op) isn't converted into any
913 // llvm op
914 potentialTerminator->insertInto(block, block->begin());
915 } else {
916 potentialTerminator->insertAfter(&block->back());
917 }
918 }
919
920 return success();
921 }
922
// General multi-block case: delegate to convertOmpOpRegions, which creates a
// continuation block and (via `phis`) the yielded values.
// NOTE(review): line 923, declaring `phis` (a SmallVector of llvm::Value*),
// is missing from this extracted listing — verify against upstream.
924 llvm::Expected<llvm::BasicBlock *> continuationBlock =
925 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
926
927 if (failed(handleError(continuationBlock, *region.getParentOp())))
928 return failure();
929
930 if (continuationBlockArgs)
931 llvm::append_range(*continuationBlockArgs, phis);
932 builder.SetInsertPoint(*continuationBlock,
933 (*continuationBlock)->getFirstInsertionPt());
934 return success();
935}
936
937namespace {
938/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
939/// store lambdas with capture.
// Non-atomic combiner: (insertPoint, lhs, rhs, &result) -> new insert point.
940using OwningReductionGen =
941 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
942 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
943 llvm::Value *&)>;
// Atomic combiner: (insertPoint, elementType, lhs, rhs) -> new insert point.
944using OwningAtomicReductionGen =
945 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
946 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
947 llvm::Value *)>;
// By-ref data pointer extraction: (insertPoint, byRefVal, &result).
948using OwningDataPtrPtrReductionGen =
949 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
950 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
951} // namespace
952
953/// Create an OpenMPIRBuilder-compatible reduction generator for the given
954/// reduction declaration. The generator uses `builder` but ignores its
955/// insertion point.
956static OwningReductionGen
957makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
958 LLVM::ModuleTranslation &moduleTranslation) {
959 // The lambda is mutable because we need access to non-const methods of decl
960 // (which aren't actually mutating it), and we must capture decl by-value to
961 // avoid the dangling reference after the parent function returns.
962 OwningReductionGen gen =
963 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
964 llvm::Value *lhs, llvm::Value *rhs,
965 llvm::Value *&result) mutable
966 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
// Bind the combiner region's block arguments to the runtime lhs/rhs values,
// then inline the region at the requested insertion point.
967 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
968 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
969 builder.restoreIP(insertPoint);
// NOTE(review): line 970, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
971 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
972 "omp.reduction.nonatomic.body", builder,
973 moduleTranslation, &phis)))
974 return llvm::createStringError(
975 "failed to inline `combiner` region of `omp.declare_reduction`");
// The combiner yields exactly one value: the combined result.
976 result = llvm::getSingleElement(phis);
977 return builder.saveIP();
978 };
979 return gen;
980}
981
982/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
983/// given reduction declaration. The generator uses `builder` but ignores its
984/// insertion point. Returns null if there is no atomic region available in the
985/// reduction declaration.
986static OwningAtomicReductionGen
987makeAtomicReductionGen(omp::DeclareReductionOp decl,
988 llvm::IRBuilderBase &builder,
989 LLVM::ModuleTranslation &moduleTranslation) {
// Without an atomic region, return an empty std::function; callers test it
// for null-ness before use.
990 if (decl.getAtomicReductionRegion().empty())
991 return OwningAtomicReductionGen();
992
993 // The lambda is mutable because we need access to non-const methods of decl
994 // (which aren't actually mutating it), and we must capture decl by-value to
995 // avoid the dangling reference after the parent function returns.
996 OwningAtomicReductionGen atomicGen =
997 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
998 llvm::Value *lhs, llvm::Value *rhs) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
// NOTE(review): line 1003, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
1004 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1005 "omp.reduction.atomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `atomic` region of `omp.declare_reduction`");
// The atomic region yields nothing; it updates memory in place.
1009 assert(phis.empty());
1010 return builder.saveIP();
1011 };
1012 return atomicGen;
1013}
1014
1015/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1016/// the given reduction declaration. The generator uses `builder` but ignores
1017/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1018/// available in the reduction declaration.
1019static OwningDataPtrPtrReductionGen
1020makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1021 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
// Only by-ref reductions carry a data_ptr_ptr region; otherwise return an
// empty std::function.
1022 if (!isByRef)
1023 return OwningDataPtrPtrReductionGen();
1024
// decl is captured by value for the same lifetime reason as in
// makeReductionGen above.
1025 OwningDataPtrPtrReductionGen refDataPtrGen =
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1027 llvm::Value *byRefVal, llvm::Value *&result) mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1030 builder.restoreIP(insertPoint);
// NOTE(review): line 1031, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
1032 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1033 "omp.data_ptr_ptr.body", builder,
1034 moduleTranslation, &phis)))
1035 return llvm::createStringError(
1036 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
// The region yields exactly one value: the extracted data pointer.
1037 result = llvm::getSingleElement(phis);
1038 return builder.saveIP();
1039 };
1040
1041 return refDataPtrGen;
1042}
1043
1044/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1045static LogicalResult
1046convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1047 LLVM::ModuleTranslation &moduleTranslation) {
1048 auto orderedOp = cast<omp::OrderedOp>(opInst);
1049
1050 if (failed(checkImplementationStatus(opInst)))
1051 return failure();
1052
1053 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1054 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1055 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1056 SmallVector<llvm::Value *> vecValues =
1057 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1058
1059 size_t indexVecValues = 0;
1060 while (indexVecValues < vecValues.size()) {
1061 SmallVector<llvm::Value *> storeValues;
1062 storeValues.reserve(numLoops);
1063 for (unsigned i = 0; i < numLoops; i++) {
1064 storeValues.push_back(vecValues[indexVecValues]);
1065 indexVecValues++;
1066 }
1067 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1068 findAllocaInsertPoint(builder, moduleTranslation);
1069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1070 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1071 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1072 }
1073 return success();
1074}
1075
1076/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1077/// OpenMPIRBuilder.
1078static LogicalResult
1079convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1080 LLVM::ModuleTranslation &moduleTranslation) {
1081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1082 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1083
1084 if (failed(checkImplementationStatus(opInst)))
1085 return failure();
1086
1087 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1088 // OrderedOp has only one region associated with it.
1089 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1090 builder.restoreIP(codeGenIP);
1091 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1092 moduleTranslation)
1093 .takeError();
1094 };
1095
1096 // TODO: Perform finalization actions for variables. This has to be
1097 // called for variables which have destructors/finalizers.
1098 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1099
1100 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1101 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1102 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1103 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1104
1105 if (failed(handleError(afterIP, opInst)))
1106 return failure();
1107
1108 builder.restoreIP(*afterIP);
1109 return success();
1110}
1111
1112namespace {
1113/// Contains the arguments for an LLVM store operation
1114struct DeferredStore {
1115 DeferredStore(llvm::Value *value, llvm::Value *address)
1116 : value(value), address(address) {}
1117
1118 llvm::Value *value;
1119 llvm::Value *address;
1120};
1121} // namespace
1122
1123/// Allocate space for privatized reduction variables.
1124/// `deferredStores` contains information to create store operations which needs
1125/// to be inserted after all allocas
1126template <typename T>
1127static LogicalResult
// NOTE(review): the first signature line (1128, declaring `loop`,
// `reductionArgs`) and line 1132 (the `reductionDecls` parameter) are missing
// from this extracted listing — verify against upstream.
1129 llvm::IRBuilderBase &builder,
1130 LLVM::ModuleTranslation &moduleTranslation,
1131 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1133 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1134 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1135 SmallVectorImpl<DeferredStore> &deferredStores,
1136 llvm::ArrayRef<bool> isByRefs) {
// All allocas go to the dedicated alloca insertion point; the guard restores
// the caller's insertion point when this function returns.
1137 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1139
1140 // delay creating stores until after all allocas
1141 deferredStores.reserve(loop.getNumReductionVars());
1142
1143 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1144 Region &allocRegion = reductionDecls[i].getAllocRegion();
1145 if (isByRefs[i]) {
1146 if (allocRegion.empty())
1147 continue;
1148
// NOTE(review): line 1149, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
1150 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1151 builder, moduleTranslation, &phis)))
1152 return loop.emitError(
1153 "failed to inline `alloc` region of `omp.declare_reduction`");
1154
1155 assert(phis.size() == 1 && "expected one allocation to be yielded");
1156 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1157
1158 // Allocate reduction variable (which is a pointer to the real reduction
1159 // variable allocated in the inlined region)
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.convertType(reductionDecls[i].getType()));
1162
// Normalize both pointers to the default pointer type so that the deferred
// store and the mappings use consistent address spaces.
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 llvm::Value *castPhi =
1167 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1168
// The store of the region-yielded pointer happens after all allocas.
1169 deferredStores.emplace_back(castPhi, castVar);
1170
1171 privateReductionVariables[i] = castVar;
1172 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1173 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1174 } else {
1175 assert(allocRegion.empty() &&
1176 "allocaction is implicit for by-val reduction");
1177 llvm::Value *var = builder.CreateAlloca(
1178 moduleTranslation.convertType(reductionDecls[i].getType()));
1179
1180 llvm::Type *ptrTy = builder.getPtrTy();
1181 llvm::Value *castVar =
1182 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1183
1184 moduleTranslation.mapValue(reductionArgs[i], castVar);
1185 privateReductionVariables[i] = castVar;
1186 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1187 }
1188 }
1189
1190 return success();
1191}
1192
1193/// Map input arguments to reduction initialization region
1194template <typename T>
1195static void
// NOTE(review): signature lines 1196 and 1198 (declaring `loop`,
// `moduleTranslation`, and `reductionDecls`) are missing from this extracted
// listing — verify against upstream.
1197 llvm::IRBuilderBase &builder,
1199 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1200 unsigned i) {
1201 // map input argument to the initialization region
1202 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1203 Region &initializerRegion = reduction.getInitializerRegion();
1204 Block &entry = initializerRegion.front();
1205
1206 mlir::Value mlirSource = loop.getReductionVars()[i];
1207 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1208 llvm::Value *origVal = llvmSource;
1209 // If a non-pointer value is expected, load the value from the source pointer.
1210 if (!isa<LLVM::LLVMPointerType>(
1211 reduction.getInitializerMoldArg().getType()) &&
1212 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1213 origVal =
1214 builder.CreateLoad(moduleTranslation.convertType(
1215 reduction.getInitializerMoldArg().getType()),
1216 llvmSource, "omp_orig")
1217 }
// Bind the "mold" block argument to the original (possibly loaded) value.
1218 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1219
// A second entry argument, when present, receives the pre-made allocation.
1220 if (entry.getNumArguments() > 1) {
1221 llvm::Value *allocation =
1222 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1223 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1224 }
1225}
1226
1227static void
1228setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1229 llvm::BasicBlock *block = nullptr) {
1230 if (block == nullptr)
1231 block = builder.GetInsertBlock();
1232
1233 if (block->empty() || block->getTerminator() == nullptr)
1234 builder.SetInsertPoint(block);
1235 else
1236 builder.SetInsertPoint(block->getTerminator());
1237}
1238
1239/// Inline reductions' `init` regions. This function assumes that the
1240/// `builder`'s insertion point is where the user wants the `init` regions to be
1241/// inlined; i.e. it does not try to find a proper insertion location for the
1242/// `init` regions. It also leaves the `builder's insertions point in a state
1243/// where the user can continue the code-gen directly afterwards.
1244template <typename OP>
1245static LogicalResult
1246initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1247 llvm::IRBuilderBase &builder,
1248 LLVM::ModuleTranslation &moduleTranslation,
1249 llvm::BasicBlock *latestAllocaBlock,
// NOTE(review): line 1250 (the `reductionDecls` parameter) is missing from
// this extracted listing — verify against upstream.
1251 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1252 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1253 llvm::ArrayRef<bool> isByRef,
1254 SmallVectorImpl<DeferredStore> &deferredStores) {
1255 if (op.getNumReductionVars() == 0)
1256 return success();
1257
// Split off a dedicated block for init-region code, and point the builder at
// the alloca block for the legacy by-ref allocations below.
1258 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1259 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1260 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1261 builder.restoreIP(allocaIP);
1262 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1263
1264 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1265 if (isByRef[i]) {
1266 if (!reductionDecls[i].getAllocRegion().empty())
1267 continue;
1268
1269 // TODO: remove after all users of by-ref are updated to use the alloc
1270 // region: Allocate reduction variable (which is a pointer to the real
1271 // reduction variable allocated in the inlined region)
1272 byRefVars[i] = builder.CreateAlloca(
1273 moduleTranslation.convertType(reductionDecls[i].getType()));
1274 }
1275 }
1276
1277 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1278
1279 // store result of the alloc region to the allocated pointer to the real
1280 // reduction variable
1281 for (auto [data, addr] : deferredStores)
1282 builder.CreateStore(data, addr);
1283
1284 // Before the loop, store the initial values of reductions into reduction
1285 // variables. Although this could be done after allocas, we don't want to mess
1286 // up with the alloca insertion point.
1287 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
// NOTE(review): line 1288, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
1289
1290 // map block argument to initializer region
1291 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1292 reductionVariableMap, i);
1293
1294 // TODO In some cases (specially on the GPU), the init regions may
1295 // contain stack allocations. If the region is inlined in a loop, this is
1296 // problematic. Instead of just inlining the region, handle allocations by
1297 // hoisting fixed length allocations to the function entry and using
1298 // stacksave and restore for variable length ones.
1299 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1300 "omp.reduction.neutral", builder,
1301 moduleTranslation, &phis)))
1302 return failure();
1303
1304 assert(phis.size() == 1 && "expected one value to be yielded from the "
1305 "reduction neutral element declaration region");
1306
1308
1309 if (isByRef[i]) {
1310 if (!reductionDecls[i].getAllocRegion().empty())
1311 // done in allocReductionVars
1312 continue;
1313
1314 // TODO: this path can be removed once all users of by-ref are updated to
1315 // use an alloc region
1316
1317 // Store the result of the inlined region to the allocated reduction var
1318 // ptr
1319 builder.CreateStore(phis[0], byRefVars[i]);
1320
1321 privateReductionVariables[i] = byRefVars[i];
1322 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1323 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1324 } else {
1325 // for by-ref case the store is inside of the reduction region
1326 builder.CreateStore(phis[0], privateReductionVariables[i]);
1327 // the rest was handled in allocByValReductionVars
1328 }
1329
1330 // forget the mapping for the initializer region because we might need a
1331 // different mapping if this reduction declaration is re-used for a
1332 // different variable
1333 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1334 }
1335
1336 return success();
1337}
1338
1339/// Collect reduction info
// Builds one owning generator triple (combiner, atomic, data_ptr_ptr) per
// reduction, then assembles the OpenMPIRBuilder::ReductionInfo list.
1340template <typename T>
1341static void collectReductionInfo(
1342 T loop, llvm::IRBuilderBase &builder,
1343 LLVM::ModuleTranslation &moduleTranslation,
// NOTE(review): parameter lines 1344-1345, 1347, and 1349 (declaring
// `reductionDecls`, `owningReductionGens`, `owningReductionGenRefDataPtrGens`,
// and `reductionInfos`) are missing from this extracted listing — verify
// against upstream.
1346 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1348 const ArrayRef<llvm::Value *> privateReductionVariables,
1350 ArrayRef<bool> isByRef) {
1351 unsigned numReductions = loop.getNumReductionVars();
1352
1353 for (unsigned i = 0; i < numReductions; ++i) {
1354 owningReductionGens.push_back(
1355 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1356 owningAtomicReductionGens.push_back(
1357 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
// NOTE(review): line 1358 (presumably the push_back of makeRefDataPtrGen) is
// missing from this extracted listing.
1359 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1360 }
1361
1362 // Collect the reduction information.
1363 reductionInfos.reserve(numReductions);
1364 for (unsigned i = 0; i < numReductions; ++i) {
// Only pass an atomic callback when the decl actually has an atomic region.
1365 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1366 if (owningAtomicReductionGens[i])
1367 atomicGen = owningAtomicReductionGens[i];
1368 llvm::Value *variable =
1369 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
// Walk the alloc region to discover the element type of the privatized
// allocation, if any.
1370 mlir::Type allocatedType;
1371 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1372 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1373 allocatedType = alloca.getElemType();
// NOTE(review): lines 1374 and 1377 (the walk's interrupt/advance results)
// are missing from this extracted listing.
1375 }
1376
1378 });
1379
1380 reductionInfos.push_back(
1381 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1382 privateReductionVariables[i],
1383 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
// NOTE(review): lines 1384 and 1386 (the non-atomic generator and the
// data_ptr_ptr generator entries) are missing from this extracted listing.
1385 /*ReductionGenClang=*/nullptr, atomicGen,
1387 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1388 reductionDecls[i].getByrefElementType()
1389 ? moduleTranslation.convertType(
1390 *reductionDecls[i].getByrefElementType())
1391 : nullptr});
1392 }
1393}
1394
1395/// handling of DeclareReductionOp's cleanup region
// Inlines each non-empty cleanup region, binding its single block argument to
// the (optionally loaded) private variable it should deallocate/finalize.
1396static LogicalResult
// NOTE(review): the first signature line (1397, declaring `cleanupRegions`) is
// missing from this extracted listing — verify against upstream.
1398 llvm::ArrayRef<llvm::Value *> privateVariables,
1399 LLVM::ModuleTranslation &moduleTranslation,
1400 llvm::IRBuilderBase &builder, StringRef regionName,
1401 bool shouldLoadCleanupRegionArg = true) {
1402 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1403 if (cleanupRegion->empty())
1404 continue;
1405
1406 // map the argument to the cleanup region
1407 Block &entry = cleanupRegion->front();
1408
// If the current block already has a terminator, insert the inlined cleanup
// code before it rather than after.
1409 llvm::Instruction *potentialTerminator =
1410 builder.GetInsertBlock()->empty() ? nullptr
1411 : &builder.GetInsertBlock()->back();
1412 if (potentialTerminator && potentialTerminator->isTerminator())
1413 builder.SetInsertPoint(potentialTerminator);
// The cleanup region's argument may expect the value itself rather than the
// pointer that was tracked; load it in that case.
1414 llvm::Value *privateVarValue =
1415 shouldLoadCleanupRegionArg
1416 ? builder.CreateLoad(
1417 moduleTranslation.convertType(entry.getArgument(0).getType()),
1418 privateVariables[i])
1419 : privateVariables[i];
1420
1421 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1422
1423 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1424 moduleTranslation)))
1425 return failure();
1426
1427 // clear block argument mapping in case it needs to be re-created with a
1428 // different source for another use of the same reduction decl
1429 moduleTranslation.forgetMapping(*cleanupRegion);
1430 }
1431 return success();
1432}
1433
1434// TODO: not used by ParallelOp
1435template <class OP>
1436static LogicalResult createReductionsAndCleanup(
1437 OP op, llvm::IRBuilderBase &builder,
1438 LLVM::ModuleTranslation &moduleTranslation,
1439 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1441 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1442 bool isNowait = false, bool isTeamsReduction = false) {
1443 // Process the reductions if required.
1444 if (op.getNumReductionVars() == 0)
1445 return success();
1446
1448 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1449 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1451
1452 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1453
1454 // Create the reduction generators. We need to own them here because
1455 // ReductionInfo only accepts references to the generators.
1456 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1457 owningReductionGens, owningAtomicReductionGens,
1458 owningReductionGenRefDataPtrGens,
1459 privateReductionVariables, reductionInfos, isByRef);
1460
1461 // The call to createReductions below expects the block to have a
1462 // terminator. Create an unreachable instruction to serve as terminator
1463 // and remove it later.
1464 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1465 builder.SetInsertPoint(tempTerminator);
1466 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1467 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1468 isByRef, isNowait, isTeamsReduction);
1469
1470 if (failed(handleError(contInsertPoint, *op)))
1471 return failure();
1472
1473 if (!contInsertPoint->getBlock())
1474 return op->emitOpError() << "failed to convert reductions";
1475
1476 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1477 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1478
1479 if (failed(handleError(afterIP, *op)))
1480 return failure();
1481
1482 tempTerminator->eraseFromParent();
1483 builder.restoreIP(*afterIP);
1484
1485 // after the construct, deallocate private reduction variables
1486 SmallVector<Region *> reductionRegions;
1487 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1488 [](omp::DeclareReductionOp reductionDecl) {
1489 return &reductionDecl.getCleanupRegion();
1490 });
1491 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1492 moduleTranslation, builder,
1493 "omp.reduction.cleanup");
1494 return success();
1495}
1496
1497static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1498 if (!attr)
1499 return {};
1500 return *attr;
1501}
1502
1503// TODO: not used by omp.parallel
// Convenience wrapper: allocates the private reduction variables and then
// inlines their `init` regions.
1504template <typename OP>
// NOTE(review): the return-type/name line (1505) and the `reductionDecls`
// parameter line (1509) are missing from this extracted listing — verify
// against upstream.
1506 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1507 LLVM::ModuleTranslation &moduleTranslation,
1508 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1510 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1511 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1512 llvm::ArrayRef<bool> isByRef) {
1513 if (op.getNumReductionVars() == 0)
1514 return success();
1515
// Stores produced during allocation are deferred until after all allocas and
// are consumed by initReductionVars below.
1516 SmallVector<DeferredStore> deferredStores;
1517
1518 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1519 allocaIP, reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 deferredStores, isByRef)))
1522 return failure();
1523
1524 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1525 allocaIP.getBlock(), reductionDecls,
1526 privateReductionVariables, reductionVariableMap,
1527 isByRef, deferredStores);
1528}
1529
1530/// Return the llvm::Value * corresponding to the `privateVar` that
1531/// is being privatized. It isn't always as simple as looking up
1532/// moduleTranslation with privateVar. For instance, in case of
1533/// an allocatable, the descriptor for the allocatable is privatized.
1534/// This descriptor is mapped using an MapInfoOp. So, this function
1535/// will return a pointer to the llvm::Value corresponding to the
1536/// block argument for the mapped descriptor.
1537static llvm::Value *
1538findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1539 LLVM::ModuleTranslation &moduleTranslation,
1540 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1541 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1542 return moduleTranslation.lookupValue(privateVar);
1543
1544 Value blockArg = (*mappedPrivateVars)[privateVar];
1545 Type privVarType = privateVar.getType();
1546 Type blockArgType = blockArg.getType();
1547 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1548 "A block argument corresponding to a mapped var should have "
1549 "!llvm.ptr type");
1550
1551 if (privVarType == blockArgType)
1552 return moduleTranslation.lookupValue(blockArg);
1553
1554 // This typically happens when the privatized type is lowered from
1555 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1556 // struct/pair is passed by value. But, mapped values are passed only as
1557 // pointers, so before we privatize, we must load the pointer.
1558 if (!isa<LLVM::LLVMPointerType>(privVarType))
1559 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1560 moduleTranslation.lookupValue(blockArg));
1561
1562 return moduleTranslation.lookupValue(privateVar);
1563}
1564
1565/// Initialize a single (first)private variable. You probably want to use
1566/// allocateAndInitPrivateVars instead of this.
1567/// This returns the private variable which has been initialized. This
1568/// variable should be mapped before constructing the body of the Op.
// NOTE(review): the return-type line (1569) is missing from this extracted
// listing; given the createStringError/value returns below it is presumably
// llvm::Expected<llvm::Value *> — verify against upstream.
1570initPrivateVar(llvm::IRBuilderBase &builder,
1571 LLVM::ModuleTranslation &moduleTranslation,
1572 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1573 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1574 llvm::BasicBlock *privInitBlock,
1575 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// No init region: the allocated private variable is already the final value.
1576 Region &initRegion = privDecl.getInitRegion();
1577 if (initRegion.empty())
1578 return llvmPrivateVar;
1579
// Bind the init region's mold (original value) and private-copy arguments.
1580 assert(nonPrivateVar);
1581 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1582 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1583
1584 // in-place convert the private initialization region
// NOTE(review): line 1585, declaring `phis`, is missing from this extracted
// listing — verify against upstream.
1586 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1587 moduleTranslation, &phis)))
1588 return llvm::createStringError(
1589 "failed to inline `init` region of `omp.private`");
1590
1591 assert(phis.size() == 1 && "expected one allocation to be yielded");
1592
1593 // clear init region block argument mapping in case it needs to be
1594 // re-created with a different source for another use of the same
1595 // reduction decl
1596 moduleTranslation.forgetMapping(initRegion);
1597
1598 // Prefer the value yielded from the init region to the allocated private
1599 // variable in case the region is operating on arguments by-value (e.g.
1600 // Fortran character boxes).
1601 return phis[0];
1602}
1603
1604/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
// NOTE(review): the return-type/name line (1605) is missing from this
// extracted listing — verify against upstream.
1606 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1607 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1608 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1609 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
// Resolve the LLVM value backing mlirPrivVar (possibly via the map-info
// mapping) and delegate to the main initPrivateVar overload.
1610 return initPrivateVar(
1611 builder, moduleTranslation, privDecl,
1612 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1613 mappedPrivateVars),
1614 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1615}
1616
// Runs the `init` region of every delayed privatizer collected in
// privateVarsInfo, remapping each block argument to the initialized value.
1617static llvm::Error
1618initPrivateVars(llvm::IRBuilderBase &builder,
1619 LLVM::ModuleTranslation &moduleTranslation,
1620 PrivateVarsInfo &privateVarsInfo,
1621 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1622 if (privateVarsInfo.blockArgs.empty())
1623 return llvm::Error::success();
1624
// All init-region code is emitted into a dedicated block.
1625 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1626 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1627
1628 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1629 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1630 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1631 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
// NOTE(review): line 1632 (the `privVarOrErr = initPrivateVar(` call head) is
// missing from this extracted listing — verify against upstream.
1633 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1634 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1635
1636 if (!privVarOrErr)
1637 return privVarOrErr.takeError();
1638
// Replace the stored private value with the (possibly different) value
// yielded by the init region, and remap the block argument to it.
1639 llvmPrivateVar = privVarOrErr.get();
1640 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1641
// NOTE(review): line 1642 (presumably a setInsertPointForPossiblyEmptyBlock
// call to recover the insertion point after inlining) is missing from this
// extracted listing.
1643 }
1644
1645 return llvm::Error::success();
1646}
1647
1648/// Allocate and initialize delayed private variables. Returns the basic block
1649/// which comes after all of these allocations. llvm::Value * for each of these
1650/// private variables are populated in llvmPrivateVars.
1652allocatePrivateVars(llvm::IRBuilderBase &builder,
1653 LLVM::ModuleTranslation &moduleTranslation,
1654 PrivateVarsInfo &privateVarsInfo,
1655 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1656 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1657 // Allocate private vars
1658 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1659 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1660 allocaTerminator->getIterator()),
1661 true, allocaTerminator->getStableDebugLoc(),
1662 "omp.region.after_alloca");
1663
1664 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1665 // Update the allocaTerminator since the alloca block was split above.
1666 allocaTerminator = allocaIP.getBlock()->getTerminator();
1667 builder.SetInsertPoint(allocaTerminator);
1668 // The new terminator is an uncondition branch created by the splitBB above.
1669 assert(allocaTerminator->getNumSuccessors() == 1 &&
1670 "This is an unconditional branch created by splitBB");
1671
1672 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1673 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1674
1675 unsigned int allocaAS =
1676 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1677 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1678 ->getDataLayout()
1679 .getProgramAddressSpace();
1680
1681 for (auto [privDecl, mlirPrivVar, blockArg] :
1682 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1683 privateVarsInfo.blockArgs)) {
1684 llvm::Type *llvmAllocType =
1685 moduleTranslation.convertType(privDecl.getType());
1686 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1687 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1688 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1689 if (allocaAS != defaultAS)
1690 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1691 builder.getPtrTy(defaultAS));
1692
1693 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1694 }
1695
1696 return afterAllocas;
1697}
1698
1699/// This can't always be determined statically, but when we can, it is good to
1700/// avoid generating compiler-added barriers which will deadlock the program.
1702 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1703 parent = parent->getParentOp()) {
1704 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1705 return true;
1706
1707 // e.g.
1708 // omp.single {
1709 // omp.parallel {
1710 // op
1711 // }
1712 // }
1713 if (mlir::isa<omp::ParallelOp>(parent))
1714 return false;
1715 }
1716 return false;
1717}
1718
1719static LogicalResult copyFirstPrivateVars(
1720 mlir::Operation *op, llvm::IRBuilderBase &builder,
1721 LLVM::ModuleTranslation &moduleTranslation,
1723 ArrayRef<llvm::Value *> llvmPrivateVars,
1724 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1725 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1726 // Apply copy region for firstprivate.
1727 bool needsFirstprivate =
1728 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1729 return privOp.getDataSharingType() ==
1730 omp::DataSharingClauseType::FirstPrivate;
1731 });
1732
1733 if (!needsFirstprivate)
1734 return success();
1735
1736 llvm::BasicBlock *copyBlock =
1737 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1738 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1739
1740 for (auto [decl, moldVar, llvmVar] :
1741 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1742 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1743 continue;
1744
1745 // copyRegion implements `lhs = rhs`
1746 Region &copyRegion = decl.getCopyRegion();
1747
1748 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1749
1750 // map copyRegion lhs arg
1751 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1752
1753 // in-place convert copy region
1754 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1755 moduleTranslation)))
1756 return decl.emitError("failed to inline `copy` region of `omp.private`");
1757
1759
1760 // ignore unused value yielded from copy region
1761
1762 // clear copy region block argument mapping in case it needs to be
1763 // re-created with different sources for reuse of the same reduction
1764 // decl
1765 moduleTranslation.forgetMapping(copyRegion);
1766 }
1767
1768 if (insertBarrier && !opIsInSingleThread(op)) {
1769 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1770 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1771 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1772 if (failed(handleError(res, *op)))
1773 return failure();
1774 }
1775
1776 return success();
1777}
1778
1779static LogicalResult copyFirstPrivateVars(
1780 mlir::Operation *op, llvm::IRBuilderBase &builder,
1781 LLVM::ModuleTranslation &moduleTranslation,
1782 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1783 ArrayRef<llvm::Value *> llvmPrivateVars,
1784 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1785 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1786 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1787 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1788 // map copyRegion rhs arg
1789 llvm::Value *moldVar = findAssociatedValue(
1790 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1791 assert(moldVar);
1792 return moldVar;
1793 });
1794 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1795 llvmPrivateVars, privateDecls, insertBarrier,
1796 mappedPrivateVars);
1797}
1798
1799static LogicalResult
1800cleanupPrivateVars(llvm::IRBuilderBase &builder,
1801 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1802 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1804 // private variable deallocation
1805 SmallVector<Region *> privateCleanupRegions;
1806 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1807 [](omp::PrivateClauseOp privatizer) {
1808 return &privatizer.getDeallocRegion();
1809 });
1810
1811 if (failed(inlineOmpRegionCleanup(
1812 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1813 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1814 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1815 "`omp.private` op in");
1816
1817 return success();
1818}
1819
1820/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1822 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1823 // be visible and not inside of function calls. This is enforced by the
1824 // verifier.
1825 return op
1826 ->walk([](Operation *child) {
1827 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1828 return WalkResult::interrupt();
1829 return WalkResult::advance();
1830 })
1831 .wasInterrupted();
1832}
1833
1834static LogicalResult
1835convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1836 LLVM::ModuleTranslation &moduleTranslation) {
1837 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1838 using StorableBodyGenCallbackTy =
1839 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1840
1841 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1842
1843 if (failed(checkImplementationStatus(opInst)))
1844 return failure();
1845
1846 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1847 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1848
1850 collectReductionDecls(sectionsOp, reductionDecls);
1851 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1852 findAllocaInsertPoint(builder, moduleTranslation);
1853
1854 SmallVector<llvm::Value *> privateReductionVariables(
1855 sectionsOp.getNumReductionVars());
1856 DenseMap<Value, llvm::Value *> reductionVariableMap;
1857
1858 MutableArrayRef<BlockArgument> reductionArgs =
1859 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1860
1862 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1863 reductionDecls, privateReductionVariables, reductionVariableMap,
1864 isByRef)))
1865 return failure();
1866
1868
1869 for (Operation &op : *sectionsOp.getRegion().begin()) {
1870 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1871 if (!sectionOp) // omp.terminator
1872 continue;
1873
1874 Region &region = sectionOp.getRegion();
1875 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1876 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1877 builder.restoreIP(codeGenIP);
1878
1879 // map the omp.section reduction block argument to the omp.sections block
1880 // arguments
1881 // TODO: this assumes that the only block arguments are reduction
1882 // variables
1883 assert(region.getNumArguments() ==
1884 sectionsOp.getRegion().getNumArguments());
1885 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1886 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1887 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1888 assert(llvmVal);
1889 moduleTranslation.mapValue(sectionArg, llvmVal);
1890 }
1891
1892 return convertOmpOpRegions(region, "omp.section.region", builder,
1893 moduleTranslation)
1894 .takeError();
1895 };
1896 sectionCBs.push_back(sectionCB);
1897 }
1898
1899 // No sections within omp.sections operation - skip generation. This situation
1900 // is only possible if there is only a terminator operation inside the
1901 // sections operation
1902 if (sectionCBs.empty())
1903 return success();
1904
1905 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1906
1907 // TODO: Perform appropriate actions according to the data-sharing
1908 // attribute (shared, private, firstprivate, ...) of variables.
1909 // Currently defaults to shared.
1910 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1911 llvm::Value &vPtr, llvm::Value *&replacementValue)
1912 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1913 replacementValue = &vPtr;
1914 return codeGenIP;
1915 };
1916
1917 // TODO: Perform finalization actions for variables. This has to be
1918 // called for variables which have destructors/finalizers.
1919 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1920
1921 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1922 bool isCancellable = constructIsCancellable(sectionsOp);
1923 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1924 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1925 moduleTranslation.getOpenMPBuilder()->createSections(
1926 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1927 sectionsOp.getNowait());
1928
1929 if (failed(handleError(afterIP, opInst)))
1930 return failure();
1931
1932 builder.restoreIP(*afterIP);
1933
1934 // Process the reductions if required.
1936 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1937 privateReductionVariables, isByRef, sectionsOp.getNowait());
1938}
1939
1940/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1941static LogicalResult
1942convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1943 LLVM::ModuleTranslation &moduleTranslation) {
1944 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1945 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1946
1947 if (failed(checkImplementationStatus(*singleOp)))
1948 return failure();
1949
1950 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1951 builder.restoreIP(codegenIP);
1952 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1953 builder, moduleTranslation)
1954 .takeError();
1955 };
1956 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1957
1958 // Handle copyprivate
1959 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1960 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1963 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1964 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1966 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1967 llvmCPFuncs.push_back(
1968 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1969 }
1970
1971 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1972 moduleTranslation.getOpenMPBuilder()->createSingle(
1973 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1974 llvmCPFuncs);
1975
1976 if (failed(handleError(afterIP, *singleOp)))
1977 return failure();
1978
1979 builder.restoreIP(*afterIP);
1980 return success();
1981}
1982
1983static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
1984 auto iface =
1985 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1986 // Check that all uses of the reduction block arg has the same distribute op
1987 // parent.
1989 Operation *distOp = nullptr;
1990 for (auto ra : iface.getReductionBlockArgs())
1991 for (auto &use : ra.getUses()) {
1992 auto *useOp = use.getOwner();
1993 // Ignore debug uses.
1994 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1995 debugUses.push_back(useOp);
1996 continue;
1997 }
1998
1999 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2000 // Use is not inside a distribute op - return false
2001 if (!currentDistOp)
2002 return false;
2003 // Multiple distribute operations - return false
2004 Operation *currentOp = currentDistOp.getOperation();
2005 if (distOp && (distOp != currentOp))
2006 return false;
2007
2008 distOp = currentOp;
2009 }
2010
2011 // If we are going to use distribute reduction then remove any debug uses of
2012 // the reduction parameters in teamsOp. Otherwise they will be left without
2013 // any mapped value in moduleTranslation and will eventually error out.
2014 for (auto *use : debugUses)
2015 use->erase();
2016 return true;
2017}
2018
2019// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
2020static LogicalResult
2021convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2022 LLVM::ModuleTranslation &moduleTranslation) {
2023 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2024 if (failed(checkImplementationStatus(*op)))
2025 return failure();
2026
2027 DenseMap<Value, llvm::Value *> reductionVariableMap;
2028 unsigned numReductionVars = op.getNumReductionVars();
2030 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
2031 llvm::ArrayRef<bool> isByRef;
2032 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2033 findAllocaInsertPoint(builder, moduleTranslation);
2034
2035 // Only do teams reduction if there is no distribute op that captures the
2036 // reduction instead.
2037 bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
2038 if (doTeamsReduction) {
2039 isByRef = getIsByRef(op.getReductionByref());
2040
2041 assert(isByRef.size() == op.getNumReductionVars());
2042
2043 MutableArrayRef<BlockArgument> reductionArgs =
2044 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2045
2046 collectReductionDecls(op, reductionDecls);
2047
2049 op, reductionArgs, builder, moduleTranslation, allocaIP,
2050 reductionDecls, privateReductionVariables, reductionVariableMap,
2051 isByRef)))
2052 return failure();
2053 }
2054
2055 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2057 moduleTranslation, allocaIP);
2058 builder.restoreIP(codegenIP);
2059 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2060 moduleTranslation)
2061 .takeError();
2062 };
2063
2064 llvm::Value *numTeamsLower = nullptr;
2065 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2066 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2067
2068 llvm::Value *numTeamsUpper = nullptr;
2069 if (!op.getNumTeamsUpperVars().empty())
2070 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2071
2072 llvm::Value *threadLimit = nullptr;
2073 if (!op.getThreadLimitVars().empty())
2074 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2075
2076 llvm::Value *ifExpr = nullptr;
2077 if (Value ifVar = op.getIfExpr())
2078 ifExpr = moduleTranslation.lookupValue(ifVar);
2079
2080 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2081 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2082 moduleTranslation.getOpenMPBuilder()->createTeams(
2083 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2084
2085 if (failed(handleError(afterIP, *op)))
2086 return failure();
2087
2088 builder.restoreIP(*afterIP);
2089 if (doTeamsReduction) {
2090 // Process the reductions if required.
2092 op, builder, moduleTranslation, allocaIP, reductionDecls,
2093 privateReductionVariables, isByRef,
2094 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2095 }
2096 return success();
2097}
2098
2099static void
2100buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2101 LLVM::ModuleTranslation &moduleTranslation,
2103 if (dependVars.empty())
2104 return;
2105 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2106 llvm::omp::RTLDependenceKindTy type;
2107 switch (
2108 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2109 case mlir::omp::ClauseTaskDepend::taskdependin:
2110 type = llvm::omp::RTLDependenceKindTy::DepIn;
2111 break;
2112 // The OpenMP runtime requires that the codegen for 'depend' clause for
2113 // 'out' dependency kind must be the same as codegen for 'depend' clause
2114 // with 'inout' dependency.
2115 case mlir::omp::ClauseTaskDepend::taskdependout:
2116 case mlir::omp::ClauseTaskDepend::taskdependinout:
2117 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2118 break;
2119 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2120 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2121 break;
2122 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2123 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2124 break;
2125 };
2126 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2127 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2128 dds.emplace_back(dd);
2129 }
2130}
2131
2132/// Shared implementation of a callback which adds a termiator for the new block
2133/// created for the branch taken when an openmp construct is cancelled. The
2134/// terminator is saved in \p cancelTerminators. This callback is invoked only
2135/// if there is cancellation inside of the taskgroup body.
2136/// The terminator will need to be fixed to branch to the correct block to
2137/// cleanup the construct.
2138static void
2140 llvm::IRBuilderBase &llvmBuilder,
2141 llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
2142 llvm::omp::Directive cancelDirective) {
2143 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2144 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2145
2146 // ip is currently in the block branched to if cancellation occurred.
2147 // We need to create a branch to terminate that block.
2148 llvmBuilder.restoreIP(ip);
2149
2150 // We must still clean up the construct after cancelling it, so we need to
2151 // branch to the block that finalizes the taskgroup.
2152 // That block has not been created yet so use this block as a dummy for now
2153 // and fix this after creating the operation.
2154 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2155 return llvm::Error::success();
2156 };
2157 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2158 // created in case the body contains omp.cancel (which will then expect to be
2159 // able to find this cleanup callback).
2160 ompBuilder.pushFinalizationCB(
2161 {finiCB, cancelDirective, constructIsCancellable(op)});
2162}
2163
2164/// If we cancelled the construct, we should branch to the finalization block of
2165/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2166/// is immediately before the continuation block. Now this finalization has
2167/// been created we can fix the branch.
2168static void
2170 llvm::OpenMPIRBuilder &ompBuilder,
2171 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2172 ompBuilder.popFinalizationCB();
2173 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2174 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2175 assert(cancelBranch->getNumSuccessors() == 1 &&
2176 "cancel branch should have one target");
2177 cancelBranch->setSuccessor(0, constructFini);
2178 }
2179}
2180
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs into the context structure, one per private decl (null for skipped
  /// decls); populated by createGEPsToPrivateVars().
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Address of the heap-allocated context structure, or null if no private
  /// variable required one.
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2238
2239void TaskContextStructManager::generateTaskContextStruct() {
2240 if (privateDecls.empty())
2241 return;
2242 privateVarTypes.reserve(privateDecls.size());
2243
2244 for (omp::PrivateClauseOp &privOp : privateDecls) {
2245 // Skip private variables which can safely be allocated and initialised
2246 // inside of the task
2247 if (!privOp.readsFromMold())
2248 continue;
2249 Type mlirType = privOp.getType();
2250 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2251 }
2252
2253 if (privateVarTypes.empty())
2254 return;
2255
2256 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2257 privateVarTypes);
2258
2259 llvm::DataLayout dataLayout =
2260 builder.GetInsertBlock()->getModule()->getDataLayout();
2261 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2262 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2263
2264 // Heap allocate the structure
2265 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2266 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2267 "omp.task.context_ptr");
2268}
2269
2270SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2271 llvm::Value *altStructPtr) const {
2272 SmallVector<llvm::Value *> ret;
2273
2274 // Create GEPs for each struct member
2275 ret.reserve(privateDecls.size());
2276 llvm::Value *zero = builder.getInt32(0);
2277 unsigned i = 0;
2278 for (auto privDecl : privateDecls) {
2279 if (!privDecl.readsFromMold()) {
2280 // Handle this inside of the task so we don't pass unnessecary vars in
2281 ret.push_back(nullptr);
2282 continue;
2283 }
2284 llvm::Value *iVal = builder.getInt32(i);
2285 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2286 ret.push_back(gep);
2287 i += 1;
2288 }
2289 return ret;
2290}
2291
2292void TaskContextStructManager::createGEPsToPrivateVars() {
2293 if (!structPtr)
2294 assert(privateVarTypes.empty());
2295 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2296 // with null values for skipped private decls
2297
2298 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2299}
2300
2301void TaskContextStructManager::freeStructPtr() {
2302 if (!structPtr)
2303 return;
2304
2305 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2306 // Ensure we don't put the call to free() after the terminator
2307 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2308 builder.CreateFree(structPtr);
2309}
2310
2311/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2312static LogicalResult
2313convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2314 LLVM::ModuleTranslation &moduleTranslation) {
2315 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2316 if (failed(checkImplementationStatus(*taskOp)))
2317 return failure();
2318
2319 PrivateVarsInfo privateVarsInfo(taskOp);
2320 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2321 privateVarsInfo.privatizers};
2322
2323 // Allocate and copy private variables before creating the task. This avoids
2324 // accessing invalid memory if (after this scope ends) the private variables
2325 // are initialized from host variables or if the variables are copied into
2326 // from host variables (firstprivate). The insertion point is just before
2327 // where the code for creating and scheduling the task will go. That puts this
2328 // code outside of the outlined task region, which is what we want because
2329 // this way the initialization and copy regions are executed immediately while
2330 // the host variable data are still live.
2331
2332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2333 findAllocaInsertPoint(builder, moduleTranslation);
2334
2335 // Not using splitBB() because that requires the current block to have a
2336 // terminator.
2337 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2338 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2339 builder.getContext(), "omp.task.start",
2340 /*Parent=*/builder.GetInsertBlock()->getParent());
2341 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2342 builder.SetInsertPoint(branchToTaskStartBlock);
2343
2344 // Now do this again to make the initialization and copy blocks
2345 llvm::BasicBlock *copyBlock =
2346 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2347 llvm::BasicBlock *initBlock =
2348 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2349
2350 // Now the control flow graph should look like
2351 // starter_block:
2352 // <---- where we started when convertOmpTaskOp was called
2353 // br %omp.private.init
2354 // omp.private.init:
2355 // br %omp.private.copy
2356 // omp.private.copy:
2357 // br %omp.task.start
2358 // omp.task.start:
2359 // <---- where we want the insertion point to be when we call createTask()
2360
2361 // Save the alloca insertion point on ModuleTranslation stack for use in
2362 // nested regions.
2364 moduleTranslation, allocaIP);
2365
2366 // Allocate and initialize private variables
2367 builder.SetInsertPoint(initBlock->getTerminator());
2368
2369 // Create task variable structure
2370 taskStructMgr.generateTaskContextStruct();
2371 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2372 // of the body otherwise it will be the GEP not the struct which is fowarded
2373 // to the outlined function. GEPs forwarded in this way are passed in a
2374 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2375 // which may not be executed until after the current stack frame goes out of
2376 // scope.
2377 taskStructMgr.createGEPsToPrivateVars();
2378
2379 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2380 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2381 privateVarsInfo.blockArgs,
2382 taskStructMgr.getLLVMPrivateVarGEPs())) {
2383 // To be handled inside the task.
2384 if (!privDecl.readsFromMold())
2385 continue;
2386 assert(llvmPrivateVarAlloc &&
2387 "reads from mold so shouldn't have been skipped");
2388
2389 llvm::Expected<llvm::Value *> privateVarOrErr =
2390 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2391 blockArg, llvmPrivateVarAlloc, initBlock);
2392 if (!privateVarOrErr)
2393 return handleError(privateVarOrErr, *taskOp.getOperation());
2394
2396
2397 // TODO: this is a bit of a hack for Fortran character boxes.
2398 // Character boxes are passed by value into the init region and then the
2399 // initialized character box is yielded by value. Here we need to store the
2400 // yielded value into the private allocation, and load the private
2401 // allocation to match the type expected by region block arguments.
2402 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2403 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2404 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2405 // Load it so we have the value pointed to by the GEP
2406 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2407 llvmPrivateVarAlloc);
2408 }
2409 assert(llvmPrivateVarAlloc->getType() ==
2410 moduleTranslation.convertType(blockArg.getType()));
2411
2412 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2413 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2414 // stack allocated structure.
2415 }
2416
2417 // firstprivate copy region
2418 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2419 if (failed(copyFirstPrivateVars(
2420 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2421 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2422 taskOp.getPrivateNeedsBarrier())))
2423 return llvm::failure();
2424
2425 // Set up for call to createTask()
2426 builder.SetInsertPoint(taskStartBlock);
2427
2428 auto bodyCB = [&](InsertPointTy allocaIP,
2429 InsertPointTy codegenIP) -> llvm::Error {
2430 // Save the alloca insertion point on ModuleTranslation stack for use in
2431 // nested regions.
2433 moduleTranslation, allocaIP);
2434
2435 // translate the body of the task:
2436 builder.restoreIP(codegenIP);
2437
2438 llvm::BasicBlock *privInitBlock = nullptr;
2439 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2440 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2441 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2442 privateVarsInfo.mlirVars))) {
2443 auto [blockArg, privDecl, mlirPrivVar] = zip;
2444 // This is handled before the task executes
2445 if (privDecl.readsFromMold())
2446 continue;
2447
2448 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2449 llvm::Type *llvmAllocType =
2450 moduleTranslation.convertType(privDecl.getType());
2451 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2452 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2453 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2454
2455 llvm::Expected<llvm::Value *> privateVarOrError =
2456 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2457 blockArg, llvmPrivateVar, privInitBlock);
2458 if (!privateVarOrError)
2459 return privateVarOrError.takeError();
2460 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2461 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2462 }
2463
2464 taskStructMgr.createGEPsToPrivateVars();
2465 for (auto [i, llvmPrivVar] :
2466 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2467 if (!llvmPrivVar) {
2468 assert(privateVarsInfo.llvmVars[i] &&
2469 "This is added in the loop above");
2470 continue;
2471 }
2472 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2473 }
2474
2475 // Find and map the addresses of each variable within the task context
2476 // structure
2477 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2478 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2479 privateVarsInfo.privatizers)) {
2480 // This was handled above.
2481 if (!privateDecl.readsFromMold())
2482 continue;
2483 // Fix broken pass-by-value case for Fortran character boxes
2484 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2485 llvmPrivateVar = builder.CreateLoad(
2486 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2487 }
2488 assert(llvmPrivateVar->getType() ==
2489 moduleTranslation.convertType(blockArg.getType()));
2490 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2491 }
2492
2493 auto continuationBlockOrError = convertOmpOpRegions(
2494 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2495 if (failed(handleError(continuationBlockOrError, *taskOp)))
2496 return llvm::make_error<PreviouslyReportedError>();
2497
2498 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2499
2500 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2501 privateVarsInfo.llvmVars,
2502 privateVarsInfo.privatizers)))
2503 return llvm::make_error<PreviouslyReportedError>();
2504
2505 // Free heap allocated task context structure at the end of the task.
2506 taskStructMgr.freeStructPtr();
2507
2508 return llvm::Error::success();
2509 };
2510
2511 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2512 SmallVector<llvm::BranchInst *> cancelTerminators;
2513 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2514 // which is canceled. This is handled here because it is the task's cleanup
2515 // block which should be branched to.
2516 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2517 llvm::omp::Directive::OMPD_taskgroup);
2518
2520 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2521 moduleTranslation, dds);
2522
2523 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2524 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2525 moduleTranslation.getOpenMPBuilder()->createTask(
2526 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2527 moduleTranslation.lookupValue(taskOp.getFinal()),
2528 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2529 taskOp.getMergeable(),
2530 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2531 moduleTranslation.lookupValue(taskOp.getPriority()));
2532
2533 if (failed(handleError(afterIP, *taskOp)))
2534 return failure();
2535
2536 // Set the correct branch target for task cancellation
2537 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2538
2539 builder.restoreIP(*afterIP);
2540 return success();
2541}
2542
/// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
///
/// Lowering shape: the current block is split so that, in control-flow order,
/// we get "omp.private.init" (heap-allocated task context struct plus
/// initialization of mold-reading privates), "omp.private.copy" (firstprivate
/// copy regions) and finally "omp.taskloop.start" where createTaskloop() is
/// invoked. Privates that do not read from their mold are instead allocated
/// and initialized inside the task body callback, and the context struct is
/// cloned by the task-duplication callback each time the runtime splits the
/// loop into more tasks.
static LogicalResult
convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto taskloopOp = cast<omp::TaskloopOp>(opInst);
  // Bail out if the op carries clauses this translation cannot handle yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // It stores the pointer of allocated firstprivate copies,
  // which can be used later for freeing the allocated space.
  SmallVector<llvm::Value *> llvmFirstPrivateVars;
  PrivateVarsInfo privateVarsInfo(taskloopOp);
  // Owns the heap-allocated context structure that carries privatized values
  // from task creation into each generated task.
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
                                         privateVarsInfo.privatizers};

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Split the current block: privatization init/copy code must run before the
  // taskloop is created, so it is emitted on the fallthrough path before
  // branching to "omp.taskloop.start".
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.taskloop.start",
      /*Parent=*/builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskloopStartBlock =
      builder.CreateBr(taskloopStartBlock);
  builder.SetInsertPoint(branchToTaskloopStartBlock);

  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.init");

  // NOTE(review): the opening line of this statement (by analogy with bodyCB
  // below, presumably a frame saved on the ModuleTranslation stack recording
  // allocaIP) is missing from this extract -- confirm against upstream.
      moduleTranslation, allocaIP);

  // Allocate and initialize private variables
  builder.SetInsertPoint(initBlock->getTerminator());

  // TODO: don't allocate if the loop has zero iterations.
  taskStructMgr.generateTaskContextStruct();
  taskStructMgr.createGEPsToPrivateVars();

  llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());

  // Initialize (only) the privates that read from their mold variable; the
  // results live in the task context structure.
  for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
    // To be handled inside the taskloop.
    if (!privDecl.readsFromMold())
      continue;
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

    llvm::Expected<llvm::Value *> privateVarOrErr =
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskloopOp.getOperation());

    llvmFirstPrivateVars[i] = privateVarOrErr.get();

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

    // Hack for Fortran character boxes passed by value: store the yielded
    // value into the private allocation, then load it back so the type
    // matches what the region block argument expects.
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
      // Load it so we have the value pointed to by the GEP
      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);
    }
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));
  }

  // firstprivate copy region
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
  if (failed(copyFirstPrivateVars(
          taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
          taskloopOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  // Set up insertion point for call to createTaskloop()
  builder.SetInsertPoint(taskloopStartBlock);

  // Generates the body of each task: allocates/initializes the privates that
  // do not read from a mold, maps every private (including those already in
  // the context struct) to its region block argument, lowers the wrapped
  // region, then runs cleanup and frees the per-task context struct copy.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    // NOTE(review): the declaration line of this statement is missing from
    // this extract -- confirm against upstream.
        moduleTranslation, allocaIP);

    // translate the body of the taskloop:
    builder.restoreIP(codegenIP);

    llvm::BasicBlock *privInitBlock = nullptr;
    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
             privateVarsInfo.mlirVars))) {
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      // This is handled before the task executes
      if (privDecl.readsFromMold())
        continue;

      // Allocate at the task's alloca insertion point, init at the current
      // codegen point; the guard restores the builder afterwards.
      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");

      llvm::Expected<llvm::Value *> privateVarOrError =
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    }

    // Recompute GEPs into the (possibly duplicated) per-task context struct
    // and prefer them over the creation-time values where present.
    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
      if (!llvmPrivVar) {
        assert(privateVarsInfo.llvmVars[i] &&
               "This is added in the loop above");
        continue;
      }
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    }

    // Find and map the addresses of each variable within the taskloop context
    // structure
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
                         privateVarsInfo.privatizers)) {
      // This was handled above.
      if (!privateDecl.readsFromMold())
        continue;
      // Fix broken pass-by-value case for Fortran character boxes
      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
      }
      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
    }

    auto continuationBlockOrError =
        convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
                            builder, moduleTranslation);

    if (failed(handleError(continuationBlockOrError, opInst)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    // This is freeing the private variables as mapped inside of the task: these
    // will be per-task private copies possibly after task duplication. This is
    // handled transparently by how these are passed to the structure passed
    // into the outlined function. When the task is duplicated, that structure
    // is duplicated too.
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  taskloopOp.getLoc(), privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();
    // Similarly, the task context structure freed inside the task is the
    // per-task copy after task duplication.
    taskStructMgr.freeStructPtr();

    return llvm::Error::success();
  };

  // Taskloop divides into an appropriate number of tasks by repeatedly
  // duplicating the original task. Each time this is done, the task context
  // structure must be duplicated too.
  auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                       llvm::Value *destPtr, llvm::Value *srcPtr)
    // NOTE(review): the lambda's trailing return type and opening brace are
    // missing from this extract -- confirm against upstream.
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.restoreIP(codegenIP);

    // Load the source context pointer and heap-allocate a fresh destination
    // context struct for the duplicated task.
    llvm::Type *ptrTy =
        builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
    llvm::Value *src =
        builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");

    TaskContextStructManager &srcStructMgr = taskStructMgr;
    TaskContextStructManager destStructMgr(builder, moduleTranslation,
                                           privateVarsInfo.privatizers);
    destStructMgr.generateTaskContextStruct();
    llvm::Value *dest = destStructMgr.getStructPtr();
    dest->setName("omp.taskloop.context.dest");
    builder.CreateStore(dest, destPtr);

    // NOTE(review): the declarations of srcGEPs/destGEPs (used by the loops
    // below) are missing from this extract -- confirm against upstream.
    srcStructMgr.createGEPsToPrivateVars(src);
    destStructMgr.createGEPsToPrivateVars(dest);

    // Inline init regions.
    for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
         llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
                         privateVarsInfo.blockArgs, destGEPs)) {
      // To be handled inside task body.
      if (!privDecl.readsFromMold())
        continue;
      assert(llvmPrivateVarAlloc &&
             "reads from mold so shouldn't have been skipped");

      llvm::Expected<llvm::Value *> privateVarOrErr =
          initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
                         llvmPrivateVarAlloc, builder.GetInsertBlock());
      if (!privateVarOrErr)
        return privateVarOrErr.takeError();


      // TODO: this is a bit of a hack for Fortran character boxes.
      // Character boxes are passed by value into the init region and then the
      // initialized character box is yielded by value. Here we need to store
      // the yielded value into the private allocation, and load the private
      // allocation to match the type expected by region block arguments.
      if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
          !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
        // Load it so we have the value pointed to by the GEP
        llvmPrivateVarAlloc = builder.CreateLoad(
            privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
      }
      assert(llvmPrivateVarAlloc->getType() ==
             moduleTranslation.convertType(blockArg.getType()));

      // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
      // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
      // through a stack allocated structure.
    }

    // Run the firstprivate copy regions from the source context into the
    // freshly allocated destination context.
    if (failed(copyFirstPrivateVars(
            &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
            privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    return builder.saveIP();
  };

  auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());

  // Deferred lookup: the CanonicalLoopInfo only becomes available once the
  // wrapped loop nest has been lowered inside bodyCB.
  auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
    llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
    return loopInfo;
  };

  Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
  Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
  Operation::operand_range steps = loopOp.getLoopSteps();
  llvm::Type *boundType =
      moduleTranslation.lookupValue(lowerBounds[0])->getType();
  llvm::Value *lbVal = nullptr;
  // Seeded with 1 so the per-loop trip counts can be multiplied in below.
  llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  llvm::Value *stepVal = nullptr;
  if (loopOp.getCollapseNumLoops() > 1) {
    // In cases where Collapse is used with Taskloop, the upper bound of the
    // iteration space needs to be recalculated to cater for the collapsed loop.
    // The Collapsed Loop UpperBound is the product of all collapsed
    // loop's tripcount.
    // The LowerBound for collapsed loops is always 1. When the loops are
    // collapsed, it will reset the bounds and introduce processing to ensure
    // the index's are presented as expected. As this happens after creating
    // Taskloop, these bounds need predicting. Example:
    // !$omp taskloop collapse(2)
    // do i = 1, 10
    // do j = 1, 5
    // ..
    // end do
    // end do
    // This loop above has a total of 50 iterations, so the lb will be 1, and
    // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
    // that i and j are properly presented when used in the loop.
    for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
      llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
      llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
      llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);
      // In some cases, such as where the ub is less than the lb so the loop
      // steps down, the calculation for the loopTripCount is swapped. To ensure
      // the correct value is found, calculate both UB - LB and LB - UB then
      // select which value to use depending on how the loop has been
      // configured.
      llvm::Value *loopLbMinusOne = builder.CreateSub(
          loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *loopUbMinusOne = builder.CreateSub(
          loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
      llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
      llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
      llvm::Value *loopTripCount =
          builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
      loopTripCount = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
      // For loops that have a step value not equal to 1, we need to adjust the
      // trip count to ensure the correct number of iterations for the loop is
      // captured.
      llvm::Value *loopTripCountDivStep =
          builder.CreateSDiv(loopTripCount, loopStep);
      loopTripCountDivStep = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
      llvm::Value *loopTripCountRem =
          builder.CreateSRem(loopTripCount, loopStep);
      loopTripCountRem = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
      // Round up when the step does not evenly divide the span.
      llvm::Value *needsRoundUp = builder.CreateICmpNE(
          loopTripCountRem,
          builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
                          0));
      loopTripCount =
          builder.CreateAdd(loopTripCountDivStep,
                            builder.CreateZExtOrTrunc(
                                needsRoundUp, loopTripCountDivStep->getType()));
      ubVal = builder.CreateMul(ubVal, loopTripCount);
    }
    lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
    stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  } else {
    lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
    ubVal = moduleTranslation.lookupValue(upperBounds[0]);
    stepVal = moduleTranslation.lookupValue(steps[0]);
  }
  assert(lbVal != nullptr && "Expected value for lbVal");
  assert(ubVal != nullptr && "Expected value for ubVal");
  assert(stepVal != nullptr && "Expected value for stepVal");

  llvm::Value *ifCond = nullptr;
  llvm::Value *grainsize = nullptr;
  // `sched` encodes the taskloop schedule: 0 = default, 1 = grainsize,
  // 2 = num_tasks. `grainsize` carries the value for whichever clause is set.
  int sched = 0; // default
  mlir::Value grainsizeVal = taskloopOp.getGrainsize();
  mlir::Value numTasksVal = taskloopOp.getNumTasks();
  if (Value ifVar = taskloopOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  if (grainsizeVal) {
    grainsize = moduleTranslation.lookupValue(grainsizeVal);
    sched = 1; // grainsize
  } else if (numTasksVal) {
    grainsize = moduleTranslation.lookupValue(numTasksVal);
    sched = 2; // num_tasks
  }

  // Only provide a duplication callback when there is a context structure
  // that actually needs cloning.
  llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
  if (taskStructMgr.getStructPtr())
    taskDupOrNull = taskDupCB;

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  SmallVector<llvm::BranchInst *> cancelTerminators;
  // The directive to match here is OMPD_taskgroup because it is the
  // taskgroup which is canceled. This is handled here because it is the
  // task's cleanup block which should be branched to. It doesn't depend upon
  // nogroup because even in that case the taskloop might still be inside an
  // explicit taskgroup.
  pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
                           llvm::omp::Directive::OMPD_taskgroup);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTaskloop(
          ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
          taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
          sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
          taskloopOp.getMergeable(),
          moduleTranslation.lookupValue(taskloopOp.getPriority()),
          loopOp.getCollapseNumLoops(), taskDupOrNull,
          taskStructMgr.getStructPtr());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Set the correct branch target for task cancellation.
  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());

  builder.restoreIP(*afterIP);
  return success();
}
2926
2927/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2928static LogicalResult
2929convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2930 LLVM::ModuleTranslation &moduleTranslation) {
2931 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2932 if (failed(checkImplementationStatus(*tgOp)))
2933 return failure();
2934
2935 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2936 builder.restoreIP(codegenIP);
2937 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2938 builder, moduleTranslation)
2939 .takeError();
2940 };
2941
2942 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2943 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2944 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2945 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2946 bodyCB);
2947
2948 if (failed(handleError(afterIP, *tgOp)))
2949 return failure();
2950
2951 builder.restoreIP(*afterIP);
2952 return success();
2953}
2954
2955static LogicalResult
2956convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2957 LLVM::ModuleTranslation &moduleTranslation) {
2958 if (failed(checkImplementationStatus(*twOp)))
2959 return failure();
2960
2961 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2962 return success();
2963}
2964
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
///
/// Handles privatization (including firstprivate copies), reductions, linear
/// clauses, schedule / dist_schedule selection, cancellation finalization and
/// the no-loop SPMD target kernel special case, then applies the workshare
/// loop transformation to the already-lowered canonical loop nest.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  // Bail out if the op carries clauses this translation cannot handle yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    // Chunk size is coerced to the induction variable's integer type.
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // When directly nested in omp.distribute, pick up dist_schedule info from
  // the parent ('distribute parallel do' lowering).
  omp::DistributeOp distributeOp = nullptr;
  llvm::Value *distScheduleChunk = nullptr;
  bool hasDistSchedule = false;
  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
    hasDistSchedule = distributeOp.getDistScheduleStatic();
    if (distributeOp.getDistScheduleChunkSize()) {
      llvm::Value *chunkVar = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
    }
  }

  PrivateVarsInfo privateVarsInfo(wsloopOp);

  // NOTE(review): the declaration of `reductionDecls` (consumed by the call
  // below and later by allocReductionVars/initReductionVars) is missing from
  // this extract -- confirm against upstream.
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // NOTE(review): the opening line of this statement (the declaration of
  // `afterAllocas`, presumably initialized from a private-var allocation
  // helper) is missing from this extract -- confirm against upstream.
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  SmallVector<DeferredStore> deferredStores;

  // Allocate storage for the reduction variables; stores into them may be
  // deferred until after initialization.
  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // Run firstprivate copy regions after all privates are initialized.
  if (failed(copyFirstPrivateVars(
          wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
          wsloopOp.getPrivateNeedsBarrier())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  // The only legal way for the direct parent to be omp.distribute is that this
  // represents 'distribute parallel do'. Otherwise, this is a regular
  // worksharing loop.
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

  SmallVector<llvm::BranchInst *> cancelTerminators;
  pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
                           llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!wsloopOp.getLinearVars().empty()) {
    auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);

    for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(
          builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
          idx);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  // NOTE(review): the opening line of this statement (the declaration of
  // `regionBlock`, presumably initialized from convertOmpOpRegions) is
  // missing from this extract -- confirm against upstream.
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);

  // Emit Initialization and Update IR for linear variables
  if (!wsloopOp.getLinearVars().empty()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    builder.restoreIP(*afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
    linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Check if we can generate no-loop kernel
  bool noLoopMode = false;
  omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
  if (targetOp) {
    Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
    // We need this check because, without it, noLoopMode would be set to true
    // for every omp.wsloop nested inside a no-loop SPMD target region, even if
    // that loop is not the top-level SPMD one.
    if (loopOp == targetCapturedOp) {
      omp::TargetRegionFlags kernelFlags =
          targetOp.getKernelExecFlags(targetCapturedOp);
      // No-loop mode requires an SPMD + no_loop kernel that is not generic.
      if (omp::bitEnumContainsAll(kernelFlags,
                                  omp::TargetRegionFlags::spmd |
                                      omp::TargetRegionFlags::no_loop) &&
          !omp::bitEnumContainsAny(kernelFlags,
                                   omp::TargetRegionFlags::generic))
        noLoopMode = true;
    }
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Emit finalization and in-place rewrites for linear vars.
  if (!wsloopOp.getLinearVars().empty()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                loopInfo->getLastIter());
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                           index);
    builder.restoreIP(oldIP);
  }

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(
          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
          privateReductionVariables, isByRef, wsloopOp.getNowait(),
          /*isTeamsReduction=*/false)))
    return failure();

  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3174
/// Converts the OpenMP parallel operation to LLVM IR.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // One flag per reduction variable: whether it is passed by reference.
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isCancellable = constructIsCancellable(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(opInst);

  // Collect reduction declarations
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Generates the body of the outlined parallel region: allocates and
  // initializes private/reduction variables, inlines the MLIR region, and
  // emits the reduction epilogue.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
        builder, moduleTranslation, privateVarsInfo, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Insert in front of the entry-block terminator so the allocas stay
    // grouped at the top of the outlined function.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    if (handleError(
            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            *opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
          owningReductionGenRefDataPtrGens;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           owningReductionGenRefDataPtrGens,
                           privateReductionVariables, reductionInfos, isByRef);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info
      // A temporary terminator gives createReductions a valid insertion
      // point; it is erased once the continuation point is known.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              builder.saveIP(), allocaIP, reductionInfos, isByRef,
              /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }

    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    // If we could be performing cancellation, add the cancellation barrier on
    // the way out of the outlined region.
    if (isCancellable) {
      auto IPOrErr = ompBuilder->createBarrier(
          llvm::OpenMPIRBuilder::LocationDescription(builder),
          llvm::omp::Directive::OMPD_unknown,
          /* ForceSimpleCall */ false,
          /* CheckCancelFlag */ false);
      if (!IPOrErr)
        return IPOrErr.takeError();
    }

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate the if/num_threads/proc_bind clauses, when present.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (!opInst.getNumThreadsVars().empty())
    numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  // Continue emitting IR after the parallel construct.
  builder.restoreIP(*afterIP);
  return success();
}
3362
3363/// Convert Order attribute to llvm::omp::OrderKind.
3364static llvm::omp::OrderKind
3365convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3366 if (!o)
3367 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3368 switch (*o) {
3369 case omp::ClauseOrderKind::Concurrent:
3370 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3371 }
3372 llvm_unreachable("Unknown ClauseOrderKind kind");
3373}
3374
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto simdOp = cast<omp::SimdOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(simdOp);

  // Gather reduction state: entry-block arguments, by-ref flags and the
  // private copies to be allocated below.
  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
  DenseMap<Value, llvm::Value *> reductionVariableMap;
  SmallVector<llvm::Value *> privateReductionVariables(
      simdOp.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;
  collectReductionDecls(simdOp, reductionDecls);
  llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!simdOp.getLinearVars().empty()) {
    auto linearVarTypes = simdOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
      bool isImplicit = false;
      for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
               privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
        // If the linear variable is implicit, reuse the already
        // existing llvm::Value
        if (linearVar == mlirPrivVar) {
          isImplicit = true;
          linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                                llvmPrivateVar, idx);
          break;
        }
      }

      // Explicit linear variable: translate the MLIR value directly.
      if (!isImplicit)
        linearClauseProcessor.createLinearVar(
            builder, moduleTranslation,
            moduleTranslation.lookupValue(linearVar), idx);
    }
    for (mlir::Value linearStep : simdOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  if (failed(allocReductionVars(simdOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
  // SIMD.

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(simdOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // Translate the simdlen/safelen clauses, when present.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());

  // Translate the aligned clause: load each aligned pointer in the block
  // that precedes the loop and record its alignment.
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");

    // Check if the alignment value is not a power of 2. If so, skip emitting
    // alignment.
    if (!intAttr.getValue().isPowerOf2())
      continue;

    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
  // Emit Initialization for linear variables
  if (simdOp.getLinearVars().size()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());

    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
  }
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Attach the SIMD transformation/metadata to the generated loop.
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  linearClauseProcessor.emitStoresForLinearVar(builder);
  for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
    linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                         index);

  // We now need to reduce the per-simd-lane reduction variable into the
  // original variable. This works a bit differently to other reductions (e.g.
  // wsloop) because we don't need to call into the OpenMP runtime to handle
  // threads: everything happened in this one thread.
  for (auto [i, tuple] : llvm::enumerate(
           llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
                     privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

    // We have one less load for by-ref case because that load is now inside of
    // the reduction region.
    llvm::Value *redValue = originalVariable;
    if (!byRef)
      redValue =
          builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        reductionType, privateReductionVar, "red.private.value." + Twine(i));
    llvm::Value *reduced;

    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    if (failed(handleError(res, opInst)))
      return failure();
    builder.restoreIP(res.get());

    // For by-ref case, the store is inside of the reduction region.
    if (!byRef)
      builder.CreateStore(reduced, originalVariable);
  }

  // After the construct, deallocate private reduction variables.
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                    moduleTranslation, builder,
                                    "omp.reduction.cleanup")))
    return failure();

  // Finally release the privatized variables.
  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3570
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop of the nest converts the MLIR region; outer
    // invocations just record their body insertion points above.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are emitted at the previously captured body insertion
      // point; their trip counts are computed in the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    ompBuilder->createCanonicalLoop(
        loc, bodyGen, lowerBound, upperBound, step,
        /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Do tiling.
  if (const auto &tiles = loopOp.getTileSizes()) {
    llvm::Type *ivType = loopInfos.front()->getIndVarType();

    // Materialize each tile size as a constant of the induction-variable type.
    for (auto tile : tiles.value()) {
      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
      tileSizes.push_back(tileVal);
    }

    std::vector<llvm::CanonicalLoopInfo *> newLoops =
        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);

    // Update afterIP to get the correct insertion point after
    // tiling.
    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
    afterIP = {afterAfterBB, afterAfterBB->begin()};

    // Update the loop infos.
    loopInfos.clear();
    for (const auto &newLoop : newLoops)
      loopInfos.push_back(newLoop);
  } // Tiling done.

  // Do collapse.
  const auto &numCollapse = loopOp.getCollapseNumLoops();
      loopInfos.begin(), loopInfos.begin() + (numCollapse));

  auto newTopLoopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});

  assert(newTopLoopInfo && "New top loop information is missing");
  // Publish the collapsed loop on the translation stack so enclosing
  // worksharing constructs pick it up as the current loop.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = newTopLoopInfo;
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);
  return success();
}
3696
/// Convert an omp.canonical_loop to LLVM-IR
static LogicalResult
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  // The trip count operand must already have been translated.
  llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);

  ompBuilder->createCanonicalLoop(
      loopLoc,
      [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
        // Register the mapping of MLIR induction variable to LLVM-IR
        // induction variable
        moduleTranslation.mapValue(loopIV, llvmIV);

        builder.restoreIP(ip);
            convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
                                moduleTranslation);

        return bodyGenStatus.takeError();
      },
      llvmTC, "omp.loop");
  if (!llvmOrError)
    return op.emitError(llvm::toString(llvmOrError.takeError()));

  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  // Continue emitting IR after the generated loop.
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(afterIP);

  // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
  if (Value cli = op.getCli())
    moduleTranslation.mapOmpLoop(cli, llvmCLI);

  return success();
}
3738
3739/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3740/// OpenMPIRBuilder.
3741static LogicalResult
3742applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3743 LLVM::ModuleTranslation &moduleTranslation) {
3744 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3745
3746 Value applyee = op.getApplyee();
3747 assert(applyee && "Loop to apply unrolling on required");
3748
3749 llvm::CanonicalLoopInfo *consBuilderCLI =
3750 moduleTranslation.lookupOMPLoop(applyee);
3751 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3752 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3753
3754 moduleTranslation.invalidateOmpLoop(applyee);
3755 return success();
3756}
3757
3758/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3759/// OpenMPIRBuilder.
3760static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3761 LLVM::ModuleTranslation &moduleTranslation) {
3762 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3763 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3764
3766 SmallVector<llvm::Value *> translatedSizes;
3767
3768 for (Value size : op.getSizes()) {
3769 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3770 assert(translatedSize &&
3771 "sizes clause arguments must already be translated");
3772 translatedSizes.push_back(translatedSize);
3773 }
3774
3775 for (Value applyee : op.getApplyees()) {
3776 llvm::CanonicalLoopInfo *consBuilderCLI =
3777 moduleTranslation.lookupOMPLoop(applyee);
3778 assert(applyee && "Canonical loop must already been translated");
3779 translatedLoops.push_back(consBuilderCLI);
3780 }
3781
3782 auto generatedLoops =
3783 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3784 if (!op.getGeneratees().empty()) {
3785 for (auto [mlirLoop, genLoop] :
3786 zip_equal(op.getGeneratees(), generatedLoops))
3787 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3788 }
3789
3790 // CLIs can only be consumed once
3791 for (Value applyee : op.getApplyees())
3792 moduleTranslation.invalidateOmpLoop(applyee);
3793
3794 return success();
3795}
3796
3797/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3798static llvm::AtomicOrdering
3799convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3800 if (!ao)
3801 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3802
3803 switch (*ao) {
3804 case omp::ClauseMemoryOrderKind::Seq_cst:
3805 return llvm::AtomicOrdering::SequentiallyConsistent;
3806 case omp::ClauseMemoryOrderKind::Acq_rel:
3807 return llvm::AtomicOrdering::AcquireRelease;
3808 case omp::ClauseMemoryOrderKind::Acquire:
3809 return llvm::AtomicOrdering::Acquire;
3810 case omp::ClauseMemoryOrderKind::Release:
3811 return llvm::AtomicOrdering::Release;
3812 case omp::ClauseMemoryOrderKind::Relaxed:
3813 return llvm::AtomicOrdering::Monotonic;
3814 }
3815 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3816}
3817
3818/// Convert omp.atomic.read operation to LLVM IR.
3819static LogicalResult
3820convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3821 LLVM::ModuleTranslation &moduleTranslation) {
3822 auto readOp = cast<omp::AtomicReadOp>(opInst);
3823 if (failed(checkImplementationStatus(opInst)))
3824 return failure();
3825
3826 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3827 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3828 findAllocaInsertPoint(builder, moduleTranslation);
3829
3830 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3831
3832 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3833 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3834 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3835
3836 llvm::Type *elementType =
3837 moduleTranslation.convertType(readOp.getElementType());
3838
3839 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3840 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3841 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3842 return success();
3843}
3844
3845/// Converts an omp.atomic.write operation to LLVM IR.
3846static LogicalResult
3847convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3848 LLVM::ModuleTranslation &moduleTranslation) {
3849 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3850 if (failed(checkImplementationStatus(opInst)))
3851 return failure();
3852
3853 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3854 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3855 findAllocaInsertPoint(builder, moduleTranslation);
3856
3857 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3858 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3859 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3860 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3861 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3862 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3863 /*isVolatile=*/false};
3864 builder.restoreIP(
3865 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3866 return success();
3867}
3868
/// Converts an LLVM dialect binary operation to the corresponding enum value
/// for `atomicrmw` supported binary operation.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
      .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
      .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
      .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
      .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
      .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
      .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
      .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
      .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
      .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
      // Unsupported operations yield BAD_BINOP; the caller then falls back
      // to generating a cmpxchg loop instead of an atomicrmw instruction.
      .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
}
3884
3885static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3886 bool &isIgnoreDenormalMode,
3887 bool &isFineGrainedMemory,
3888 bool &isRemoteMemory) {
3889 isIgnoreDenormalMode = false;
3890 isFineGrainedMemory = false;
3891 isRemoteMemory = false;
3892 if (atomicUpdateOp &&
3893 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3894 mlir::omp::AtomicControlAttr atomicControlAttr =
3895 atomicUpdateOp.getAtomicControlAttr();
3896 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3897 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3898 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3899 }
3900}
3901
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Record which side of the binary op carries `x`, and translate the
    // opposite operand (the update expression).
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  // Callback invoked by the IRBuilder to emit the update computation (e.g.
  // inside a cmpxchg retry loop): map the region argument to the current
  // atomic value, inline the region, and return the yielded result.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  // Continue emitting IR after the atomic construct.
  builder.restoreIP(*afterIP);
  return success();
}
3986
/// Converts an `omp.atomic.capture` operation to LLVM IR using the
/// OpenMPIRBuilder. The capture region pairs an atomic read with either an
/// atomic write or an atomic update; whichever of the two is present drives
/// how the captured value and the update expression are lowered.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  // BAD_BINOP signals to the IRBuilder that no single RMW instruction
  // matches and a cmpxchg loop must be generated instead.
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // `v = x; x = expr` — the old value is always captured before the write.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix when the update follows the read inside the capture region.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      // Exactly one op plus the terminator: try to match it to an RMW binop.
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      // Record on which side of the binop `x` appears; needed for
      // non-commutative operations.
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one computation op: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback invoked by the IRBuilder to emit the update computation; it
  // receives the current atomic value and must return the new value.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    // Splice the update region's block into the current insertion point,
    // mapping its block argument to the freshly loaded atomic value.
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4087
4088static llvm::omp::Directive convertCancellationConstructType(
4089 omp::ClauseCancellationConstructType directive) {
4090 switch (directive) {
4091 case omp::ClauseCancellationConstructType::Loop:
4092 return llvm::omp::Directive::OMPD_for;
4093 case omp::ClauseCancellationConstructType::Parallel:
4094 return llvm::omp::Directive::OMPD_parallel;
4095 case omp::ClauseCancellationConstructType::Sections:
4096 return llvm::omp::Directive::OMPD_sections;
4097 case omp::ClauseCancellationConstructType::Taskgroup:
4098 return llvm::omp::Directive::OMPD_taskgroup;
4099 }
4100 llvm_unreachable("Unhandled cancellation construct type");
4101}
4102
4103static LogicalResult
4104convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4105 LLVM::ModuleTranslation &moduleTranslation) {
4106 if (failed(checkImplementationStatus(*op.getOperation())))
4107 return failure();
4108
4109 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4110 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4111
4112 llvm::Value *ifCond = nullptr;
4113 if (Value ifVar = op.getIfExpr())
4114 ifCond = moduleTranslation.lookupValue(ifVar);
4115
4116 llvm::omp::Directive cancelledDirective =
4117 convertCancellationConstructType(op.getCancelDirective());
4118
4119 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4120 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4121
4122 if (failed(handleError(afterIP, *op.getOperation())))
4123 return failure();
4124
4125 builder.restoreIP(afterIP.get());
4126
4127 return success();
4128}
4129
4130static LogicalResult
4131convertOmpCancellationPoint(omp::CancellationPointOp op,
4132 llvm::IRBuilderBase &builder,
4133 LLVM::ModuleTranslation &moduleTranslation) {
4134 if (failed(checkImplementationStatus(*op.getOperation())))
4135 return failure();
4136
4137 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4138 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4139
4140 llvm::omp::Directive cancelledDirective =
4141 convertCancellationConstructType(op.getCancelDirective());
4142
4143 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4144 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4145
4146 if (failed(handleError(afterIP, *op.getOperation())))
4147 return failure();
4148
4149 builder.restoreIP(afterIP.get());
4150
4151 return success();
4152}
4153
4154/// Converts an OpenMP Threadprivate operation into LLVM IR using
4155/// OpenMPIRBuilder.
4156static LogicalResult
4157convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4158 LLVM::ModuleTranslation &moduleTranslation) {
4159 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4160 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4161 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4162
4163 if (failed(checkImplementationStatus(opInst)))
4164 return failure();
4165
4166 Value symAddr = threadprivateOp.getSymAddr();
4167 auto *symOp = symAddr.getDefiningOp();
4168
4169 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4170 symOp = asCast.getOperand().getDefiningOp();
4171
4172 if (!isa<LLVM::AddressOfOp>(symOp))
4173 return opInst.emitError("Addressing symbol not found");
4174 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4175
4176 LLVM::GlobalOp global =
4177 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4178 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4179 llvm::Type *type = globalValue->getValueType();
4180 llvm::TypeSize typeSize =
4181 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4182 type);
4183 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4184 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4185 ompLoc, globalValue, size, global.getSymName() + ".cache");
4186 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4187
4188 return success();
4189}
4190
4191static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4192convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4193 switch (deviceClause) {
4194 case mlir::omp::DeclareTargetDeviceType::host:
4195 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4196 break;
4197 case mlir::omp::DeclareTargetDeviceType::nohost:
4198 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4199 break;
4200 case mlir::omp::DeclareTargetDeviceType::any:
4201 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4202 break;
4203 }
4204 llvm_unreachable("unhandled device clause");
4205}
4206
4207static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4209 mlir::omp::DeclareTargetCaptureClause captureClause) {
4210 switch (captureClause) {
4211 case mlir::omp::DeclareTargetCaptureClause::to:
4212 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4213 case mlir::omp::DeclareTargetCaptureClause::link:
4214 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4215 case mlir::omp::DeclareTargetCaptureClause::enter:
4216 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4217 case mlir::omp::DeclareTargetCaptureClause::none:
4218 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4219 }
4220 llvm_unreachable("unhandled capture clause");
4221}
4222
4224 Operation *op = value.getDefiningOp();
4225 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4226 op = addrCast->getOperand(0).getDefiningOp();
4227 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4228 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4229 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4230 }
4231 return nullptr;
4232}
4233
4234static llvm::SmallString<64>
4235getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
4236 llvm::OpenMPIRBuilder &ompBuilder) {
4237 llvm::SmallString<64> suffix;
4238 llvm::raw_svector_ostream os(suffix);
4239 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
4240 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
4241 auto fileInfoCallBack = [&loc]() {
4242 return std::pair<std::string, uint64_t>(
4243 llvm::StringRef(loc.getFilename()), loc.getLine());
4244 };
4245
4246 auto vfs = llvm::vfs::getRealFileSystem();
4247 os << llvm::format(
4248 "_%x",
4249 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4250 }
4251 os << "_decl_tgt_ref_ptr";
4252
4253 return suffix;
4254}
4255
4256static bool isDeclareTargetLink(Value value) {
4257 if (auto declareTargetGlobal =
4258 dyn_cast_if_present<omp::DeclareTargetInterface>(
4259 getGlobalOpFromValue(value)))
4260 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4261 omp::DeclareTargetCaptureClause::link)
4262 return true;
4263 return false;
4264}
4265
4266static bool isDeclareTargetTo(Value value) {
4267 if (auto declareTargetGlobal =
4268 dyn_cast_if_present<omp::DeclareTargetInterface>(
4269 getGlobalOpFromValue(value)))
4270 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4271 omp::DeclareTargetCaptureClause::to ||
4272 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4273 omp::DeclareTargetCaptureClause::enter)
4274 return true;
4275 return false;
4276}
4277
4278// Returns the reference pointer generated by the lowering of the declare
4279// target operation in cases where the link clause is used or the to clause is
4280// used in USM mode.
4281static llvm::Value *
4283 LLVM::ModuleTranslation &moduleTranslation) {
4284 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4285 if (auto gOp =
4286 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
4287 if (auto declareTargetGlobal =
4288 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4289 // In this case, we must utilise the reference pointer generated by
4290 // the declare target operation, similar to Clang
4291 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4292 omp::DeclareTargetCaptureClause::link) ||
4293 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4294 omp::DeclareTargetCaptureClause::to &&
4295 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4296 llvm::SmallString<64> suffix =
4297 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
4298
4299 if (gOp.getSymName().contains(suffix))
4300 return moduleTranslation.getLLVMModule()->getNamedValue(
4301 gOp.getSymName());
4302
4303 return moduleTranslation.getLLVMModule()->getNamedValue(
4304 (gOp.getSymName().str() + suffix.str()).str());
4305 }
4306 }
4307 }
4308 return nullptr;
4309}
4310
4311namespace {
4312// Append customMappers information to existing MapInfosTy
4313struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4314 SmallVector<Operation *, 4> Mappers;
4315
4316 /// Append arrays in \a CurInfo.
4317 void append(MapInfosTy &curInfo) {
4318 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4319 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4320 }
4321};
4322// A small helper structure to contain data gathered
4323// for map lowering and coalese it into one area and
4324// avoiding extra computations such as searches in the
4325// llvm module for lowered mapped variables or checking
4326// if something is declare target (and retrieving the
4327// value) more than neccessary.
4328struct MapInfoData : MapInfosTy {
4329 llvm::SmallVector<bool, 4> IsDeclareTarget;
4330 llvm::SmallVector<bool, 4> IsAMember;
4331 // Identify if mapping was added by mapClause or use_device clauses.
4332 llvm::SmallVector<bool, 4> IsAMapping;
4333 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4334 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4335 // Stripped off array/pointer to get the underlying
4336 // element type
4337 llvm::SmallVector<llvm::Type *, 4> BaseType;
4338
4339 /// Append arrays in \a CurInfo.
4340 void append(MapInfoData &CurInfo) {
4341 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4342 CurInfo.IsDeclareTarget.end());
4343 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4344 OriginalValue.append(CurInfo.OriginalValue.begin(),
4345 CurInfo.OriginalValue.end());
4346 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4347 MapInfosTy::append(CurInfo);
4348 }
4349};
4350
4351enum class TargetDirectiveEnumTy : uint32_t {
4352 None = 0,
4353 Target = 1,
4354 TargetData = 2,
4355 TargetEnterData = 3,
4356 TargetExitData = 4,
4357 TargetUpdate = 5
4358};
4359
4360static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4361 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4362 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4363 .Case([](omp::TargetEnterDataOp) {
4364 return TargetDirectiveEnumTy::TargetEnterData;
4365 })
4366 .Case([&](omp::TargetExitDataOp) {
4367 return TargetDirectiveEnumTy::TargetExitData;
4368 })
4369 .Case([&](omp::TargetUpdateOp) {
4370 return TargetDirectiveEnumTy::TargetUpdate;
4371 })
4372 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4373 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4374}
4375
4376} // namespace
4377
4378static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4379 DataLayout &dl) {
4380 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4381 arrTy.getElementType()))
4382 return getArrayElementSizeInBits(nestedArrTy, dl);
4383 return dl.getTypeSizeInBits(arrTy.getElementType());
4384}
4385
// This function calculates the size to be offloaded for a specified type, given
// its associated map clause (which can contain bounds information which affects
// the total size), this size is calculated based on the underlying element type
// e.g. given a 1-D array of ints, we will calculate the size from the integer
// type * number of elements in the array. This size can be used in other
// calculations but is ultimately used as an argument to the OpenMP runtimes
// kernel argument structure which is generated through the combinedInfo data
// structures.
// This function is somewhat equivalent to Clang's getExprTypeSize inside of
// CGOpenMPRuntime.cpp.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      // Runtime element count: product over all dimensions of (UB - LB + 1).
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      // For arrays, the element count already covers the dimensions, so use
      // the innermost element's size rather than the whole array's.
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underlying types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No bounds: fall back to the static size of the type itself.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
4445
4446// Convert the MLIR map flag set to the runtime map flag set for embedding
4447// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4448// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4449// Certain flags are discarded here such as RefPtee and co.
4450static llvm::omp::OpenMPOffloadMappingFlags
4451convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4452 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4453 return (mlirFlags & flag) == flag;
4454 };
4455 const bool hasExplicitMap =
4456 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4457 omp::ClauseMapFlags::none;
4458
4459 llvm::omp::OpenMPOffloadMappingFlags mapType =
4460 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4461
4462 if (mapTypeToBool(omp::ClauseMapFlags::to))
4463 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4464
4465 if (mapTypeToBool(omp::ClauseMapFlags::from))
4466 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4467
4468 if (mapTypeToBool(omp::ClauseMapFlags::always))
4469 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4470
4471 if (mapTypeToBool(omp::ClauseMapFlags::del))
4472 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4473
4474 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4475 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4476
4477 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4478 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4479
4480 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4481 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4482
4483 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4484 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4485
4486 if (mapTypeToBool(omp::ClauseMapFlags::close))
4487 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4488
4489 if (mapTypeToBool(omp::ClauseMapFlags::present))
4490 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4491
4492 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4493 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4494
4495 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4496 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4497
4498 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4499 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4500 if (!hasExplicitMap)
4501 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4502 }
4503
4504 return mapType;
4505}
4506
4508 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
4509 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
4510 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
4511 ArrayRef<Value> useDevAddrOperands = {},
4512 ArrayRef<Value> hasDevAddrOperands = {}) {
4513 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4514 // Check if this is a member mapping and correctly assign that it is, if
4515 // it is a member of a larger object.
4516 // TODO: Need better handling of members, and distinguishing of members
4517 // that are implicitly allocated on device vs explicitly passed in as
4518 // arguments.
4519 // TODO: May require some further additions to support nested record
4520 // types, i.e. member maps that can have member maps.
4521 for (Value mapValue : mapVars) {
4522 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4523 for (auto member : map.getMembers())
4524 if (member == mapOp)
4525 return true;
4526 }
4527 return false;
4528 };
4529
4530 // Process MapOperands
4531 for (Value mapValue : mapVars) {
4532 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4533 Value offloadPtr =
4534 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4535 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4536 mapData.Pointers.push_back(mapData.OriginalValue.back());
4537
4538 if (llvm::Value *refPtr =
4539 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
4540 mapData.IsDeclareTarget.push_back(true);
4541 mapData.BasePointers.push_back(refPtr);
4542 } else if (isDeclareTargetTo(offloadPtr)) {
4543 mapData.IsDeclareTarget.push_back(true);
4544 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4545 } else { // regular mapped variable
4546 mapData.IsDeclareTarget.push_back(false);
4547 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4548 }
4549
4550 mapData.BaseType.push_back(
4551 moduleTranslation.convertType(mapOp.getVarType()));
4552 mapData.Sizes.push_back(
4553 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4554 mapData.BaseType.back(), builder, moduleTranslation));
4555 mapData.MapClause.push_back(mapOp.getOperation());
4556 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
4557 mapData.Names.push_back(LLVM::createMappingInformation(
4558 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4559 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4560 if (mapOp.getMapperId())
4561 mapData.Mappers.push_back(
4563 mapOp, mapOp.getMapperIdAttr()));
4564 else
4565 mapData.Mappers.push_back(nullptr);
4566 mapData.IsAMapping.push_back(true);
4567 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4568 }
4569
4570 auto findMapInfo = [&mapData](llvm::Value *val,
4571 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4572 unsigned index = 0;
4573 bool found = false;
4574 for (llvm::Value *basePtr : mapData.OriginalValue) {
4575 if (basePtr == val && mapData.IsAMapping[index]) {
4576 found = true;
4577 mapData.Types[index] |=
4578 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4579 mapData.DevicePointers[index] = devInfoTy;
4580 }
4581 index++;
4582 }
4583 return found;
4584 };
4585
4586 // Process useDevPtr(Addr)Operands
4587 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4588 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4589 for (Value mapValue : useDevOperands) {
4590 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4591 Value offloadPtr =
4592 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4593 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4594
4595 // Check if map info is already present for this entry.
4596 if (!findMapInfo(origValue, devInfoTy)) {
4597 mapData.OriginalValue.push_back(origValue);
4598 mapData.Pointers.push_back(mapData.OriginalValue.back());
4599 mapData.IsDeclareTarget.push_back(false);
4600 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4601 mapData.BaseType.push_back(
4602 moduleTranslation.convertType(mapOp.getVarType()));
4603 mapData.Sizes.push_back(builder.getInt64(0));
4604 mapData.MapClause.push_back(mapOp.getOperation());
4605 mapData.Types.push_back(
4606 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4607 mapData.Names.push_back(LLVM::createMappingInformation(
4608 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4609 mapData.DevicePointers.push_back(devInfoTy);
4610 mapData.Mappers.push_back(nullptr);
4611 mapData.IsAMapping.push_back(false);
4612 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4613 }
4614 }
4615 };
4616
4617 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4618 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4619
4620 for (Value mapValue : hasDevAddrOperands) {
4621 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4622 Value offloadPtr =
4623 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4624 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4625 auto mapType = convertClauseMapFlags(mapOp.getMapType());
4626 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4627 bool isDevicePtr =
4628 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4629 omp::ClauseMapFlags::none;
4630
4631 mapData.OriginalValue.push_back(origValue);
4632 mapData.BasePointers.push_back(origValue);
4633 mapData.Pointers.push_back(origValue);
4634 mapData.IsDeclareTarget.push_back(false);
4635 mapData.BaseType.push_back(
4636 moduleTranslation.convertType(mapOp.getVarType()));
4637 mapData.Sizes.push_back(
4638 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4639 mapData.MapClause.push_back(mapOp.getOperation());
4640 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4641 // Descriptors are mapped with the ALWAYS flag, since they can get
4642 // rematerialized, so the address of the decriptor for a given object
4643 // may change from one place to another.
4644 mapData.Types.push_back(mapType);
4645 // Technically it's possible for a non-descriptor mapping to have
4646 // both has-device-addr and ALWAYS, so lookup the mapper in case it
4647 // exists.
4648 if (mapOp.getMapperId()) {
4649 mapData.Mappers.push_back(
4651 mapOp, mapOp.getMapperIdAttr()));
4652 } else {
4653 mapData.Mappers.push_back(nullptr);
4654 }
4655 } else {
4656 // For is_device_ptr we need the map type to propagate so the runtime
4657 // can materialize the device-side copy of the pointer container.
4658 mapData.Types.push_back(
4659 isDevicePtr ? mapType
4660 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4661 mapData.Mappers.push_back(nullptr);
4662 }
4663 mapData.Names.push_back(LLVM::createMappingInformation(
4664 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4665 mapData.DevicePointers.push_back(
4666 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4667 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4668 mapData.IsAMapping.push_back(false);
4669 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4670 }
4671}
4672
4673static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4674 auto *res = llvm::find(mapData.MapClause, memberOp);
4675 assert(res != mapData.MapClause.end() &&
4676 "MapInfoOp for member not found in MapData, cannot return index");
4677 return std::distance(mapData.MapClause.begin(), res);
4678}
4679
4681 omp::MapInfoOp mapInfo) {
4682 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4683 llvm::SmallVector<size_t> occludedChildren;
4684 llvm::sort(
4685 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
4686 // Bail early if we are asked to look at the same index. If we do not
4687 // bail early, we can end up mistakenly adding indices to
4688 // occludedChildren. This can occur with some types of libc++ hardening.
4689 if (a == b)
4690 return false;
4691
4692 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4693 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4694
4695 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4696 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4697 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4698
4699 if (aIndex == bIndex)
4700 continue;
4701
4702 if (aIndex < bIndex)
4703 return true;
4704
4705 if (aIndex > bIndex)
4706 return false;
4707 }
4708
4709 // Iterated up until the end of the smallest member and
4710 // they were found to be equal up to that point, so select
4711 // the member with the lowest index count, so the "parent"
4712 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4713 if (memberAParent)
4714 occludedChildren.push_back(b);
4715 else
4716 occludedChildren.push_back(a);
4717 return memberAParent;
4718 });
4719
4720 // We remove children from the index list that are overshadowed by
4721 // a parent, this prevents us retrieving these as the first or last
4722 // element when the parent is the correct element in these cases.
4723 for (auto v : occludedChildren)
4724 indices.erase(std::remove(indices.begin(), indices.end(), v),
4725 indices.end());
4726}
4727
4728static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4729 bool first) {
4730 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4731 // Only 1 member has been mapped, we can return it.
4732 if (indexAttr.size() == 1)
4733 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4734 llvm::SmallVector<size_t> indices(indexAttr.size());
4735 std::iota(indices.begin(), indices.end(), 0);
4736 sortMapIndices(indices, mapInfo);
4737 return llvm::cast<omp::MapInfoOp>(
4738 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4739 .getDefiningOp());
4740}
4741
4742/// This function calculates the array/pointer offset for map data provided
4743/// with bounds operations, e.g. when provided something like the following:
4744///
4745/// Fortran
4746/// map(tofrom: array(2:5, 3:2))
4747///
4748/// We must calculate the initial pointer offset to pass across, this function
4749/// performs this using bounds.
4750///
4751/// TODO/WARNING: This only supports Fortran's column major indexing currently
4752/// as is noted in the note below and comments in the function, we must extend
4753/// this function when we add a C++ frontend.
4754/// NOTE: which while specified in row-major order it currently needs to be
4755/// flipped for Fortran's column order array allocation and access (as
4756/// opposed to C++'s row-major, hence the backwards processing where order is
4757/// important). This is likely important to keep in mind for the future when
4758/// we incorporate a C++ frontend, both frontends will need to agree on the
4759/// ordering of generated bounds operations (one may have to flip them) to
4760/// make the below lowering frontend agnostic. The offload size
4761/// calcualtion may also have to be adjusted for C++.
4762static std::vector<llvm::Value *>
4764 llvm::IRBuilderBase &builder, bool isArrayTy,
4765 OperandRange bounds) {
4766 std::vector<llvm::Value *> idx;
4767 // There's no bounds to calculate an offset from, we can safely
4768 // ignore and return no indices.
4769 if (bounds.empty())
4770 return idx;
4771
4772 // If we have an array type, then we have its type so can treat it as a
4773 // normal GEP instruction where the bounds operations are simply indexes
4774 // into the array. We currently do reverse order of the bounds, which
4775 // I believe leans more towards Fortran's column-major in memory.
4776 if (isArrayTy) {
4777 idx.push_back(builder.getInt64(0));
4778 for (int i = bounds.size() - 1; i >= 0; --i) {
4779 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4780 bounds[i].getDefiningOp())) {
4781 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4782 }
4783 }
4784 } else {
4785 // If we do not have an array type, but we have bounds, then we're dealing
4786 // with a pointer that's being treated like an array and we have the
4787 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
4788 // address (pointer pointing to the actual data) so we must caclulate the
4789 // offset using a single index which the following loop attempts to
4790 // compute using the standard column-major algorithm e.g for a 3D array:
4791 //
4792 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
4793 //
4794 // It is of note that it's doing column-major rather than row-major at the
4795 // moment, but having a way for the frontend to indicate which major format
4796 // to use or standardizing/canonicalizing the order of the bounds to compute
4797 // the offset may be useful in the future when there's other frontends with
4798 // different formats.
4799 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4800 for (int i = bounds.size() - 1; i >= 0; --i) {
4801 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4802 bounds[i].getDefiningOp())) {
4803 if (i == ((int)bounds.size() - 1))
4804 idx.emplace_back(
4805 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4806 else
4807 idx.back() = builder.CreateAdd(
4808 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4809 boundOp.getExtent())),
4810 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4811 }
4812 }
4813 }
4814
4815 return idx;
4816}
4817
4819 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4820 return cast<IntegerAttr>(value).getInt();
4821 });
4822}
4823
4824// Gathers members that are overlapping in the parent, excluding members that
4825// themselves overlap, keeping the top-most (closest to parents level) map.
4826static void
4828 omp::MapInfoOp parentOp) {
4829 // No members mapped, no overlaps.
4830 if (parentOp.getMembers().empty())
4831 return;
4832
4833 // Single member, we can insert and return early.
4834 if (parentOp.getMembers().size() == 1) {
4835 overlapMapDataIdxs.push_back(0);
4836 return;
4837 }
4838
4839 // 1) collect list of top-level overlapping members from MemberOp
4841 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4842 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4843 memberByIndex.push_back(
4844 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4845
4846 // Sort the smallest first (higher up the parent -> member chain), so that
4847 // when we remove members, we remove as much as we can in the initial
4848 // iterations, shortening the number of passes required.
4849 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4850 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
4851
4852 // Remove elements from the vector if there is a parent element that
4853 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
4854 // [0,2].. etc.
4856 for (auto v : memberByIndex) {
4857 llvm::SmallVector<int64_t> vArr(v.second.size());
4858 getAsIntegers(v.second, vArr);
4859 skipList.push_back(
4860 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
4861 if (v == x)
4862 return false;
4863 llvm::SmallVector<int64_t> xArr(x.second.size());
4864 getAsIntegers(x.second, xArr);
4865 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4866 xArr.size() >= vArr.size();
4867 }));
4868 }
4869
4870 // Collect the indices, as we need the base pointer etc. from the MapData
4871 // structure which is primarily accessible via index at the moment.
4872 for (auto v : memberByIndex)
4873 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4874 overlapMapDataIdxs.push_back(v.first);
4875}
4876
4877// The intent is to verify if the mapped data being passed is a
4878// pointer -> pointee that requires special handling in certain cases,
4879// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4880//
4881// There may be a better way to verify this, but unfortunately with
4882// opaque pointers we lose the ability to easily check if something is
4883// a pointer whilst maintaining access to the underlying type.
4884static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4885 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4886 if (mapOp.getVarPtrPtr())
4887 return true;
4888
4889 // If the map data is declare target with a link clause, then it's represented
4890 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4891 // no relation to pointers.
4892 if (isDeclareTargetLink(mapOp.getVarPtr()))
4893 return true;
4894
4895 return false;
4896}
4897
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause. Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map
// type with it) to indicate that a member is part of this parent and should
// be treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  auto *parentMapper = mapData.Mappers[mapDataIndex];

  // Map the first segment of the parent. If a user-defined mapper is attached,
  // include the parent's to/from-style bits (and common modifiers) in this
  // base entry so the mapper receives correct copy semantics via its 'type'
  // parameter. Also keep TARGET_PARAM when required for kernel arguments.
  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  if (parentMapper) {
    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
    // Preserve relevant map-type bits from the parent clause. These include
    // the copy direction (TO/FROM), as well as commonly used modifiers that
    // should be visible to the mapper for correct behaviour.
    mapFlags parentFlags = mapData.Types[mapDataIndex];
    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
    baseFlag |= (parentFlags & preserve);
  }

  // Base entry for the parent; the parallel combinedInfo arrays must be kept
  // the same length, so every push below has a counterpart in each array.
  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  // Only attach the mapper to the base entry when we are mapping the whole
  // parent. Combined/segment entries must not carry a mapper; otherwise the
  // mapper can be invoked with a partial size, which is undefined behaviour.
  combinedInfo.Mappers.emplace_back(
      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Whole-parent map: the span is [parent, parent + 1 element).
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: the span covers only the explicitly mapped members,
    // from the lowest-positioned member to one past the highest.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  // The MEMBER_OF flag encodes the 1-based position of the base entry just
  // emitted; member entries use it to reference their parent.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // Single MEMBER_OF entry covering the whole parent.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
    } else {
      llvm::SmallVector<size_t> overlapIdxs;
      // Find all of the members that "overlap", i.e. occlude other members that
      // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
      // [0,2], but not [1,0].
      getOverlappedMembers(overlapIdxs, parentClause);
      // We need to make sure the overlapped members are sorted in order of
      // lowest address to highest address.
      sortMapIndices(overlapIdxs, parentClause);

      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      // TODO: We may want to skip arrays/array sections in this as Clang does.
      // It appears to be an optimisation rather than a necessity though,
      // but this requires further investigation. However, we would have to make
      // sure to not exclude maps with bounds that ARE pointers, as these are
      // processed as separate components, i.e. pointer + data.
      //
      // Emit one entry per "gap" between overlapping members: each entry
      // spans [lowAddr, start of the overlapping member), then lowAddr is
      // advanced to one past that member for the next iteration.
      for (auto v : overlapIdxs) {
        auto mapDataOverlapIdx = getMapDataMemberIdx(
            mapData,
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(nullptr);
        combinedInfo.Pointers.emplace_back(lowAddr);
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
                                  lowAddr),
            builder.getInt64Ty(), /*isSigned=*/true));
        lowAddr = builder.CreateConstGEP1_32(
            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      // Final entry: from the end of the last overlapping member to the end
      // of the parent.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), true));
    }
  }
  return memberOfFlag;
}
5092
5093// This function is intended to add explicit mappings of members
5095 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5096 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5097 MapInfoData &mapData, uint64_t mapDataIndex,
5098 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5099 TargetDirectiveEnumTy targetDirective) {
5100 assert(!ompBuilder.Config.isTargetDevice() &&
5101 "function only supported for host device codegen");
5102
5103 auto parentClause =
5104 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5105
5106 for (auto mappedMembers : parentClause.getMembers()) {
5107 auto memberClause =
5108 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5109 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5110
5111 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5112
5113 // If we're currently mapping a pointer to a block of data, we must
5114 // initially map the pointer, and then attatch/bind the data with a
5115 // subsequent map to the pointer. This segment of code generates the
5116 // pointer mapping, which can in certain cases be optimised out as Clang
5117 // currently does in its lowering. However, for the moment we do not do so,
5118 // in part as we currently have substantially less information on the data
5119 // being mapped at this stage.
5120 if (checkIfPointerMap(memberClause)) {
5121 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5122 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5123 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5124 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5125 combinedInfo.Types.emplace_back(mapFlag);
5126 combinedInfo.DevicePointers.emplace_back(
5127 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5128 combinedInfo.Mappers.emplace_back(nullptr);
5129 combinedInfo.Names.emplace_back(
5130 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5131 combinedInfo.BasePointers.emplace_back(
5132 mapData.BasePointers[mapDataIndex]);
5133 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5134 combinedInfo.Sizes.emplace_back(builder.getInt64(
5135 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5136 }
5137
5138 // Same MemberOfFlag to indicate its link with parent and other members
5139 // of.
5140 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5141 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5142 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5143 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5144 bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5145 ? parentClause.getVarPtr()
5146 : parentClause.getVarPtrPtr());
5147 if (checkIfPointerMap(memberClause) &&
5148 (!isDeclTargetTo ||
5149 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5150 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5151 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5152 }
5153
5154 combinedInfo.Types.emplace_back(mapFlag);
5155 combinedInfo.DevicePointers.emplace_back(
5156 mapData.DevicePointers[memberDataIdx]);
5157 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5158 combinedInfo.Names.emplace_back(
5159 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5160 uint64_t basePointerIndex =
5161 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5162 combinedInfo.BasePointers.emplace_back(
5163 mapData.BasePointers[basePointerIndex]);
5164 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5165
5166 llvm::Value *size = mapData.Sizes[memberDataIdx];
5167 if (checkIfPointerMap(memberClause)) {
5168 size = builder.CreateSelect(
5169 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5170 builder.getInt64(0), size);
5171 }
5172
5173 combinedInfo.Sizes.emplace_back(size);
5174 }
5175}
5176
5177static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5178 MapInfosTy &combinedInfo,
5179 TargetDirectiveEnumTy targetDirective,
5180 int mapDataParentIdx = -1) {
5181 // Declare Target Mappings are excluded from being marked as
5182 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5183 // marked with OMP_MAP_PTR_AND_OBJ instead.
5184 auto mapFlag = mapData.Types[mapDataIdx];
5185 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5186
5187 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5188 if (isPtrTy)
5189 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5190
5191 if (targetDirective == TargetDirectiveEnumTy::Target &&
5192 !mapData.IsDeclareTarget[mapDataIdx])
5193 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5194
5195 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5196 !isPtrTy)
5197 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5198
5199 // if we're provided a mapDataParentIdx, then the data being mapped is
5200 // part of a larger object (in a parent <-> member mapping) and in this
5201 // case our BasePointer should be the parent.
5202 if (mapDataParentIdx >= 0)
5203 combinedInfo.BasePointers.emplace_back(
5204 mapData.BasePointers[mapDataParentIdx]);
5205 else
5206 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5207
5208 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5209 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5210 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5211 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5212 combinedInfo.Types.emplace_back(mapFlag);
5213 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5214}
5215
5217 llvm::IRBuilderBase &builder,
5218 llvm::OpenMPIRBuilder &ompBuilder,
5219 DataLayout &dl, MapInfosTy &combinedInfo,
5220 MapInfoData &mapData, uint64_t mapDataIndex,
5221 TargetDirectiveEnumTy targetDirective) {
5222 assert(!ompBuilder.Config.isTargetDevice() &&
5223 "function only supported for host device codegen");
5224
5225 auto parentClause =
5226 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5227
5228 // If we have a partial map (no parent referenced in the map clauses of the
5229 // directive, only members) and only a single member, we do not need to bind
5230 // the map of the member to the parent, we can pass the member separately.
5231 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5232 auto memberClause = llvm::cast<omp::MapInfoOp>(
5233 parentClause.getMembers()[0].getDefiningOp());
5234 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5235 // Note: Clang treats arrays with explicit bounds that fall into this
5236 // category as a parent with map case, however, it seems this isn't a
5237 // requirement, and processing them as an individual map is fine. So,
5238 // we will handle them as individual maps for the moment, as it's
5239 // difficult for us to check this as we always require bounds to be
5240 // specified currently and it's also marginally more optimal (single
5241 // map rather than two). The difference may come from the fact that
5242 // Clang maps array without bounds as pointers (which we do not
5243 // currently do), whereas we treat them as arrays in all cases
5244 // currently.
5245 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
5246 mapDataIndex);
5247 return;
5248 }
5249
5250 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5251 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
5252 combinedInfo, mapData, mapDataIndex,
5253 targetDirective);
5254 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
5255 combinedInfo, mapData, mapDataIndex,
5256 memberOfParentFlag, targetDirective);
5257}
5258
// This is a variation on Clang's GenerateOpenMPCapturedVars, which
// generates different operation (e.g. load/store) combinations for
// arguments to the kernel, based on map capture kinds which are then
// utilised in the combinedInfo in place of the original Map value.
//
// Mutates mapData.Pointers (and, for by-copy captures, BasePointers) in
// place; it must run before the map entries are materialised from mapData.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defined types, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        // Compute the bounds-derived starting offset (e.g. Fortran array
        // section lower bounds); empty when the clause carries no bounds.
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        // Pointer-backed data must be dereferenced once before applying the
        // offset, so the GEP indexes the pointee rather than the pointer.
        if (isPtrTy)
          newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        if (!isPtrTy) {
          // Spill the loaded value to a temporary alloca placed at the
          // function's alloca insertion point, then reload it; the insertion
          // point and debug location are saved and restored around the hop.
          auto curInsert = builder.saveIP();
          llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
          builder.SetCurrentDebugLocation(DbgLoc);
          builder.restoreIP(curInsert);

          builder.CreateStore(newV, memTempAlloc);
          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        // Not yet supported; report on the offending op and continue.
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}
5329
5330// Generate all map related information and fill the combinedInfo.
5331static void genMapInfos(llvm::IRBuilderBase &builder,
5332 LLVM::ModuleTranslation &moduleTranslation,
5333 DataLayout &dl, MapInfosTy &combinedInfo,
5334 MapInfoData &mapData,
5335 TargetDirectiveEnumTy targetDirective) {
5336 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5337 "function only supported for host device codegen");
5338 // We wish to modify some of the methods in which arguments are
5339 // passed based on their capture type by the target region, this can
5340 // involve generating new loads and stores, which changes the
5341 // MLIR value to LLVM value mapping, however, we only wish to do this
5342 // locally for the current function/target and also avoid altering
5343 // ModuleTranslation, so we remap the base pointer or pointer stored
5344 // in the map infos corresponding MapInfoData, which is later accessed
5345 // by genMapInfos and createTarget to help generate the kernel and
5346 // kernel arg structure. It primarily becomes relevant in cases like
5347 // bycopy, or byref range'd arrays. In the default case, we simply
5348 // pass thee pointer byref as both basePointer and pointer.
5349 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5350
5351 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5352
5353 // We operate under the assumption that all vectors that are
5354 // required in MapInfoData are of equal lengths (either filled with
5355 // default constructed data or appropiate information) so we can
5356 // utilise the size from any component of MapInfoData, if we can't
5357 // something is missing from the initial MapInfoData construction.
5358 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5359 // NOTE/TODO: We currently do not support arbitrary depth record
5360 // type mapping.
5361 if (mapData.IsAMember[i])
5362 continue;
5363
5364 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5365 if (!mapInfoOp.getMembers().empty()) {
5366 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5367 combinedInfo, mapData, i, targetDirective);
5368 continue;
5369 }
5370
5371 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5372 }
5373}
5374
// Forward declaration: emits the LLVM-IR function implementing a
// user-defined mapper. Needed here because it is mutually recursive with
// getOrCreateUserDefinedMapperFunc below; the definition follows it.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective);
5380
5381static llvm::Expected<llvm::Function *>
5382getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5383 LLVM::ModuleTranslation &moduleTranslation,
5384 TargetDirectiveEnumTy targetDirective) {
5385 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5386 "function only supported for host device codegen");
5387 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5388 std::string mapperFuncName =
5389 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5390 {"omp_mapper", declMapperOp.getSymName()});
5391
5392 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5393 return lookupFunc;
5394
5395 // Recursive types can cause re-entrant mapper emission. The mapper function
5396 // is created by OpenMPIRBuilder before the callbacks run, so it may already
5397 // exist in the LLVM module even though it is not yet registered in the
5398 // ModuleTranslation mapping table. Reuse and register it to break the
5399 // recursion.
5400 if (llvm::Function *existingFunc =
5401 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
5402 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
5403 return existingFunc;
5404 }
5405
5406 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5407 mapperFuncName, targetDirective);
5408}
5409
// Emits the LLVM-IR function implementing the user-defined mapper declared
// by `op` (an omp::DeclareMapperOp) under the name `mapperFuncName`,
// delegating the function-body scaffolding to
// OpenMPIRBuilder::emitUserDefinedMapper and filling in the map info via
// callbacks. Returns the emitted function or a previously reported error.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  // `combinedInfo` is captured by reference so customMapperCB (run later by
  // OpenMPIRBuilder) can see the entries genMapInfoCB produced.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's symbolic variable to the pointer PHI supplied by the
    // builder, then translate the declare-mapper region's body in place.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Resolves the nested user-defined mapper (if any) attached to entry `i`.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // A recursive emission may have registered the function already (see
  // getOrCreateUserDefinedMapperFunc); only register it if it is new.
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
5471
/// Translate `omp.target_data`, `omp.target_enter_data`,
/// `omp.target_exit_data` and `omp.target_update` operations into LLVM IR by
/// delegating to OpenMPIRBuilder::createTargetData. Host-side only.
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  // Default device id until a `device` clause is seen.
  llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  // Runtime entry point used by the standalone (enter/exit/update) forms.
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
  TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/true,
      /*SeparateBeginEndCalls=*/true);
  assert(!ompBuilder->Config.isTargetDevice() &&
         "target data/enter/exit/update are host ops");
  // Offloading is enabled iff at least one target triple was configured.
  bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();

  // Translate a `device` clause operand to an i64 device id.
  auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
    llvm::Value *v = moduleTranslation.lookupValue(dev);
    return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
  };

  // Dispatch on the concrete op kind to collect clause values (if/device/map
  // operands) and, for the standalone directives, select the runtime function.
  LogicalResult result =
          .Case([&](omp::TargetDataOp dataOp) {
            if (failed(checkImplementationStatus(*dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = dataOp.getDevice())
              deviceID = getDeviceID(devId);

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = enterDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = exitDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = updateDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .DefaultUnreachable("unexpected operation");

  if (failed(result))
    return failure();
  // Pretend we have IF(false) if we're not doing offload.
  if (!isOffloadEntry)
    ifCond = builder.getFalse();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
                                builder, useDevicePtrVars, useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
                targetDirective);
    return combinedInfo;
  };

  // Define a lambda to apply mappings between use_device_addr and
  // use_device_ptr base pointers, and their associated block arguments.
  // An entry is matched by comparing the base pointer of each map clause with
  // that of the use_device_* variable; `mapper` (if given) may transform the
  // base pointer (e.g. load through it) or return null to skip the mapping.
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(blockArgs, useDeviceVars)) {

          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   mapInfoData.MapClause, mapInfoData.DevicePointers,
                   mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(arg, devPtrInfoMap);
              break;
            }
          }
        }
      };

  // Generate the body of `omp.target_data`; invoked by OpenMPIRBuilder at up
  // to three points (Priv / DupNoPriv / NoPriv) of the begin/end sequence.
  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We must always restoreIP regardless of doing anything the caller
    // does not restore it, leading to incorrect (no) branch generation.
    builder.restoreIP(codeGenIP);
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available
      if (!info.DevicePtrInfoMap.empty()) {
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           builder.getPtrTy(),
                           info.DevicePtrInfoMap[basePointer].second);
                     });
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      if (info.DevicePtrInfoMap.empty()) {
        // For host device we still need to do the mapping for codegen,
        // otherwise it may try to lookup a missing value.
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData);
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData);
      }
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then region has already been generated
      if (info.DevicePtrInfoMap.empty()) {
        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  // Resolve the user-defined mapper (if any) attached to the i-th map clause.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  // The region-bearing `omp.target_data` form passes a body generator; the
  // standalone forms instead pass the runtime function selected above.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(op))
      return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                          deviceID, ifCond, info, genMapInfoCB,
                                          customMapperCB,
                                          /*MapperFunc=*/nullptr, bodyGenCB,
                                          /*DeviceAddrCB=*/nullptr);
    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                        deviceID, ifCond, info, genMapInfoCB,
                                        customMapperCB, &RTLFn);
  }();

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
5712
/// Translate an `omp.distribute` operation into LLVM IR via
/// OpenMPIRBuilder::createDistribute, optionally handling an enclosing
/// `omp.teams` reduction when that reduction is contained in the distribute.
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  /// Process teams op reduction in distribute if the reduction is contained in
  /// the distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  if (doDistributeReduction) {
    isByRef = getIsByRef(teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(teamsOp, reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
            .getReductionBlockArgs();

        teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Body generator invoked by createDistribute: privatization, region
  // translation, and (when not fused with a nested wsloop) the distribute
  // workshare-loop lowering.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(codeGenIP);
    PrivateVarsInfo privVarsInfo(distributeOp);

        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
    if (handleError(afterAllocas, opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                    opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
            privVarsInfo.llvmVars, privVarsInfo.privatizers,
            distributeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
        convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

    // Skip applying a workshare loop below when translating 'distribute
    // parallel do' (it's been already handled by this point while translating
    // the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
                                      : omp::ClauseScheduleKind::Static;
      // dist_schedule clauses are ordered - otherise this should be false
      bool isOrdered = hasDistSchedule;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      // Null when no dist_schedule chunk size was given.
      llvm::Value *chunk = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
              convertToScheduleKind(schedule), chunk, isSimd,
              scheduleMod == omp::ScheduleModifier::monotonic,
              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
              workshareLoopType, false, hasDistSchedule, chunk);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
                                  privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
        teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
5848
5849/// Lowers the FlagsAttr which is applied to the module on the device
5850/// pass when offloading, this attribute contains OpenMP RTL globals that can
5851/// be passed as flags to the frontend, otherwise they are set to default
5852static LogicalResult
5853convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5854 LLVM::ModuleTranslation &moduleTranslation) {
5855 if (!cast<mlir::ModuleOp>(op))
5856 return failure();
5857
5858 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5859
5860 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5861 attribute.getOpenmpDeviceVersion());
5862
5863 if (attribute.getNoGpuLib())
5864 return success();
5865
5866 ompBuilder->createGlobalFlag(
5867 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5868 "__omp_rtl_debug_kind");
5869 ompBuilder->createGlobalFlag(
5870 attribute
5871 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5872 ,
5873 "__omp_rtl_assume_teams_oversubscription");
5874 ompBuilder->createGlobalFlag(
5875 attribute
5876 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5877 ,
5878 "__omp_rtl_assume_threads_oversubscription");
5879 ompBuilder->createGlobalFlag(
5880 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5881 "__omp_rtl_assume_no_thread_state");
5882 ompBuilder->createGlobalFlag(
5883 attribute
5884 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5885 ,
5886 "__omp_rtl_assume_no_nested_parallelism");
5887 return success();
5888}
5889
5890static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5891 omp::TargetOp targetOp,
5892 llvm::StringRef parentName = "") {
5893 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5894
5895 assert(fileLoc && "No file found from location");
5896 StringRef fileName = fileLoc.getFilename().getValue();
5897
5898 llvm::sys::fs::UniqueID id;
5899 uint64_t line = fileLoc.getLine();
5900 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5901 size_t fileHash = llvm::hash_value(fileName.str());
5902 size_t deviceId = 0xdeadf17e;
5903 targetInfo =
5904 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5905 } else {
5906 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5907 id.getFile(), line);
5908 }
5909}
5910
/// On the device pass, rewrite in-kernel uses of `declare target` mapped
/// globals so they go through the map-generated reference pointer instead of
/// the original global (see the detailed comment in the loop body).
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for target device codegen");
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // In the case of declare target mapped variables, the basePointer is
    // the reference pointer generated by the convertDeclareTargetAttr
    // method. Whereas the kernelValue is the original variable, so for
    // the device we must replace all uses of this original global variable
    // (stored in kernelValue) with the reference pointer (stored in
    // basePointer for declare target mapped variables), as for device the
    // data is mapped into this reference pointer and should be loaded
    // from it, the original variable is discarded. On host both exist and
    // metadata is generated (elsewhere in the convertDeclareTargetAttr)
    // function to link the two variables in the runtime and then both the
    // reference pointer and the pointer are assigned in the kernel argument
    // structure for the host.
    if (mapData.IsDeclareTarget[i]) {
      // If the original map value is a constant, then we have to make sure all
      // of it's uses within the current kernel/function that we are going to
      // rewrite are converted to instructions, as we will be altering the old
      // use (OriginalValue) from a constant to an instruction, which will be
      // illegal and ICE the compiler if the user is a constant expression of
      // some kind e.g. a constant GEP.
      if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(constant, func, false);

      // The users iterator will get invalidated if we modify an element,
      // so we populate this vector of uses to alter each user on an
      // individual basis to emit its own load (rather than one load for
      // all).
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(user);

      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
          // Only rewrite uses inside the kernel function being processed.
          if (insn->getFunction() == func) {
            auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
            llvm::Value *substitute = mapData.BasePointers[i];
            // `declare target link` indirects through the reference pointer,
            // so each use gets its own load of it, placed right before the
            // user and carrying the user's debug location.
            if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
                                                         : mapOp.getVarPtr())) {
              builder.SetCurrentDebugLocation(insn->getDebugLoc());
              substitute = builder.CreateLoad(
                  mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
              cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
            }
            user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
          }
        }
      }
    }
  }
}
5968
5969// The createDeviceArgumentAccessor function generates
5970// instructions for retrieving (acessing) kernel
5971// arguments inside of the device kernel for use by
5972// the kernel. This enables different semantics such as
5973// the creation of temporary copies of data allowing
5974// semantics like read-only/no host write back kernel
5975// arguments.
5976//
5977// This currently implements a very light version of Clang's
5978// EmitParmDecl's handling of direct argument handling as well
5979// as a portion of the argument access generation based on
5980// capture types found at the end of emitOutlinedFunctionPrologue
5981// in Clang. The indirect path handling of EmitParmDecl's may be
5982// required for future work, but a direct 1-to-1 copy doesn't seem
5983// possible as the logic is rather scattered throughout Clang's
5984// lowering and perhaps we wish to deviate slightly.
5985//
5986// \param mapData - A container containing vectors of information
5987// corresponding to the input argument, which should have a
5988// corresponding entry in the MapInfoData containers
5989// OrigialValue's.
5990// \param arg - This is the generated kernel function argument that
5991// corresponds to the passed in input argument. We generated different
5992// accesses of this Argument, based on capture type and other Input
5993// related information.
5994// \param input - This is the host side value that will be passed to
5995// the kernel i.e. the kernel input, we rewrite all uses of this within
5996// the kernel (as we generate the kernel body based on the target's region
5997// which maintians references to the original input) to the retVal argument
5998// apon exit of this function inside of the OMPIRBuilder. This interlinks
5999// the kernel argument to future uses of it in the function providing
6000// appropriate "glue" instructions inbetween.
6001// \param retVal - This is the value that all uses of input inside of the
6002// kernel will be re-written to, the goal of this function is to generate
6003// an appropriate location for the kernel argument to be accessed from,
6004// e.g. ByRef will result in a temporary allocation location and then
6005// a store of the kernel argument into this allocated memory which
6006// will then be loaded from, ByCopy will use the allocated memory
6007// directly.
6008static llvm::IRBuilderBase::InsertPoint
6009createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
6010 llvm::Value *input, llvm::Value *&retVal,
6011 llvm::IRBuilderBase &builder,
6012 llvm::OpenMPIRBuilder &ompBuilder,
6013 LLVM::ModuleTranslation &moduleTranslation,
6014 llvm::IRBuilderBase::InsertPoint allocaIP,
6015 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6016 assert(ompBuilder.Config.isTargetDevice() &&
6017 "function only supported for target device codegen");
6018 builder.restoreIP(allocaIP);
6019
6020 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6021 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
6022 ompBuilder.M.getContext());
6023 unsigned alignmentValue = 0;
6024 // Find the associated MapInfoData entry for the current input
6025 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
6026 if (mapData.OriginalValue[i] == input) {
6027 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6028 capture = mapOp.getMapCaptureType();
6029 // Get information of alignment of mapped object
6030 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
6031 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6032 break;
6033 }
6034
6035 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6036 unsigned int defaultAS =
6037 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6038
6039 // Create the alloca for the argument the current point.
6040 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6041
6042 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6043 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6044
6045 builder.CreateStore(&arg, v);
6046
6047 builder.restoreIP(codeGenIP);
6048
6049 switch (capture) {
6050 case omp::VariableCaptureKind::ByCopy: {
6051 retVal = v;
6052 break;
6053 }
6054 case omp::VariableCaptureKind::ByRef: {
6055 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6056 v->getType(), v,
6057 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6058 // CreateAlignedLoad function creates similar LLVM IR:
6059 // %res = load ptr, ptr %input, align 8
6060 // This LLVM IR does not contain information about alignment
6061 // of the loaded value. We need to add !align metadata to unblock
6062 // optimizer. The existence of the !align metadata on the instruction
6063 // tells the optimizer that the value loaded is known to be aligned to
6064 // a boundary specified by the integer value in the metadata node.
6065 // Example:
6066 // %res = load ptr, ptr %input, align 8, !align !align_md_node
6067 // ^ ^
6068 // | |
6069 // alignment of %input address |
6070 // |
6071 // alignment of %res object
6072 if (v->getType()->isPointerTy() && alignmentValue) {
6073 llvm::MDBuilder MDB(builder.getContext());
6074 loadInst->setMetadata(
6075 llvm::LLVMContext::MD_align,
6076 llvm::MDNode::get(builder.getContext(),
6077 MDB.createConstant(llvm::ConstantInt::get(
6078 llvm::Type::getInt64Ty(builder.getContext()),
6079 alignmentValue))));
6080 }
6081 retVal = loadInst;
6082
6083 break;
6084 }
6085 case omp::VariableCaptureKind::This:
6086 case omp::VariableCaptureKind::VLAType:
6087 // TODO: Consider returning error to use standard reporting for
6088 // unimplemented features.
6089 assert(false && "Currently unsupported capture kind");
6090 break;
6091 }
6092
6093 return builder.saveIP();
6094}
6095
/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                       Value &numTeamsLower, Value &numTeamsUpper,
                       Value &threadLimit,
                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
  // Walk host_eval operands together with their corresponding entry block
  // arguments, then classify each block argument by how it is consumed.
  for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                   blockArgIface.getHostEvalBlockArgs())) {
    Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);

    for (Operation *user : blockArg.getUsers()) {
          .Case([&](omp::TeamsOp teamsOp) {
            if (teamsOp.getNumTeamsLower() == blockArg)
              numTeamsLower = hostEvalVar;
            else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
                                        blockArg))
              numTeamsUpper = hostEvalVar;
            else if (!teamsOp.getThreadLimitVars().empty() &&
                     teamsOp.getThreadLimit(0) == blockArg)
              threadLimit = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::ParallelOp parallelOp) {
            if (!parallelOp.getNumThreadsVars().empty() &&
                parallelOp.getNumThreads(0) == blockArg)
              numThreads = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::LoopNestOp loopOp) {
            // A block argument can appear in several bound positions at once;
            // record every position in the provided output vector (if any)
            // and report whether at least one match was found.
            auto processBounds =
                [&](OperandRange opBounds,
                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
              bool found = false;
              for (auto [i, lb] : llvm::enumerate(opBounds)) {
                if (lb == blockArg) {
                  found = true;
                  if (outBounds)
                    (*outBounds)[i] = hostEvalVar;
                }
              }
              return found;
            };
            bool found =
                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
                    found;
            found = processBounds(loopOp.getLoopSteps(), steps) || found;
            (void)found;
            assert(found && "unsupported host_eval use");
          })
          .DefaultUnreachable("unsupported host_eval use");
    }
  }
}
6162
6163/// If \p op is of the given type parameter, return it casted to that type.
6164/// Otherwise, if its immediate parent operation (or some other higher-level
6165/// parent, if \p immediateParent is false) is of that type, return that parent
6166/// casted to the given type.
6167///
6168/// If \p op is \c null or neither it or its parent(s) are of the specified
6169/// type, return a \c null operation.
6170template <typename OpTy>
6171static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6172 if (!op)
6173 return OpTy();
6174
6175 if (OpTy casted = dyn_cast<OpTy>(op))
6176 return casted;
6177
6178 if (immediateParent)
6179 return dyn_cast_if_present<OpTy>(op->getParentOp());
6180
6181 return op->getParentOfType<OpTy>();
6182}
6183
6184/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6185/// it is of an integer type, return its value.
6186static std::optional<int64_t> extractConstInteger(Value value) {
6187 if (!value)
6188 return std::nullopt;
6189
6190 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6191 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6192 return constAttr.getInt();
6193
6194 return std::nullopt;
6195}
6196
6197static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6198 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6199 uint64_t sizeInBytes = sizeInBits / 8;
6200 return sizeInBytes;
6201}
6202
/// Compute the byte size of an (unpacked) struct containing one member per
/// reduction variable type of \p op; returns 0 when the op declares no
/// reduction variables.
template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
  if (op.getNumReductionVars() > 0) {
    collectReductionDecls(op, reductions);

    members.reserve(reductions.size());
    for (omp::DeclareReductionOp &red : reductions)
      members.push_back(red.getType());
    Operation *opp = op.getOperation();
    // Model the reduction variables as a literal struct so the data layout
    // computes the aggregate size, including inter-member padding.
    auto structType = mlir::LLVM::LLVMStructType::getLiteral(
        opp->getContext(), members, /*isPacked=*/false);
    DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
    return getTypeByteSize(structType, dl);
  }
  return 0;
}
6221
6222/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
6223/// values as stated by the corresponding clauses, if constant.
6224///
6225/// These default values must be set before the creation of the outlined LLVM
6226/// function for the target region, so that they can be used to initialize the
6227/// corresponding global `ConfigurationEnvironmentTy` structure.
6228static void
6229initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
6230 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6231 bool isTargetDevice, bool isGPU) {
6232 // TODO: Handle constant 'if' clauses.
6233
6234 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6235 if (!isTargetDevice) {
6236 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6237 threadLimit);
6238 } else {
6239 // In the target device, values for these clauses are not passed as
6240 // host_eval, but instead evaluated prior to entry to the region. This
6241 // ensures values are mapped and available inside of the target region.
6242 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6243 numTeamsLower = teamsOp.getNumTeamsLower();
6244 // Handle num_teams upper bounds (only first value for now)
6245 if (!teamsOp.getNumTeamsUpperVars().empty())
6246 numTeamsUpper = teamsOp.getNumTeams(0);
6247 if (!teamsOp.getThreadLimitVars().empty())
6248 threadLimit = teamsOp.getThreadLimit(0);
6249 }
6250
6251 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
6252 if (!parallelOp.getNumThreadsVars().empty())
6253 numThreads = parallelOp.getNumThreads(0);
6254 }
6255 }
6256
6257 // Handle clauses impacting the number of teams.
6258
6259 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6260 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6261 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
6262 // match clang and set min and max to the same value.
6263 if (numTeamsUpper) {
6264 if (auto val = extractConstInteger(numTeamsUpper))
6265 minTeamsVal = maxTeamsVal = *val;
6266 } else {
6267 minTeamsVal = maxTeamsVal = 0;
6268 }
6269 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
6270 /*immediateParent=*/true) ||
6272 /*immediateParent=*/true)) {
6273 minTeamsVal = maxTeamsVal = 1;
6274 } else {
6275 minTeamsVal = maxTeamsVal = -1;
6276 }
6277
6278 // Handle clauses impacting the number of threads.
6279
6280 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
6281 if (!clauseValue)
6282 return;
6283
6284 if (auto val = extractConstInteger(clauseValue))
6285 result = *val;
6286
6287 // Found an applicable clause, so it's not undefined. Mark as unknown
6288 // because it's not constant.
6289 if (result < 0)
6290 result = 0;
6291 };
6292
6293 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
6294 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6295 if (!targetOp.getThreadLimitVars().empty())
6296 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6297 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6298
6299 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
6300 int32_t maxThreadsVal = -1;
6302 setMaxValueFromClause(numThreads, maxThreadsVal);
6303 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
6304 /*immediateParent=*/true))
6305 maxThreadsVal = 1;
6306
6307 // For max values, < 0 means unset, == 0 means set but unknown. Select the
6308 // minimum value between 'max_threads' and 'thread_limit' clauses that were
6309 // set.
6310 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6311 if (combinedMaxThreadsVal < 0 ||
6312 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6313 combinedMaxThreadsVal = teamsThreadLimitVal;
6314
6315 if (combinedMaxThreadsVal < 0 ||
6316 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6317 combinedMaxThreadsVal = maxThreadsVal;
6318
6319 int32_t reductionDataSize = 0;
6320 if (isGPU && capturedOp) {
6321 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
6322 reductionDataSize = getReductionDataSize(teamsOp);
6323 }
6324
6325 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
6326 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6327 assert(
6328 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6329 omp::TargetRegionFlags::spmd) &&
6330 "invalid kernel flags");
6331 attrs.ExecFlags =
6332 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6333 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6334 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6335 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6336 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6337 if (omp::bitEnumContainsAll(kernelFlags,
6338 omp::TargetRegionFlags::spmd |
6339 omp::TargetRegionFlags::no_loop) &&
6340 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6341 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6342
6343 attrs.MinTeams = minTeamsVal;
6344 attrs.MaxTeams.front() = maxTeamsVal;
6345 attrs.MinThreads = 1;
6346 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6347 attrs.ReductionDataSize = reductionDataSize;
6348 // TODO: Allow modified buffer length similar to
6349 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
6350 if (attrs.ReductionDataSize != 0)
6351 attrs.ReductionBufferLength = 1024;
6352}
6353
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  // Number of collapsed loops in the captured loop nest, if any; this sizes
  // the bound/step vectors handed to extractHostEvalClauses below.
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // If any bound is unavailable, the combined trip count cannot be
      // computed: reset it and stop.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop: seed the accumulator; subsequent loops multiply into it.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Default to an undefined device id; an explicit `device` clause overrides
  // it below, adjusted to the runtime's int64 ABI.
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
6441
/// Lowers an `omp.target` operation: outlines the target region into a kernel
/// function (and host fallback) and emits the launch sequence via
/// `OpenMPIRBuilder::createTarget`.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  // The current debug location already has the DISubprogram for the outlined
  // function that will be created for the target op. We save it here so that
  // we can set it on the outlined function.
  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // During the handling of target op, we will generate instructions in the
  // parent function like call to the oulined function or branch to a new
  // BasicBlock. We set the debug location here to parent function so that those
  // get the correct debug locations. For outlined functions, the normal MLIR op
  // conversion will automatically pick the correct location.
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
        outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block
  // argument that corresponds to the MapInfoOp corresponding to the private
  // var in question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;
  TargetDirectiveEnumTy targetDirective =
      getTargetDirectiveEnumTyFromOp(&opInst);

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that need a map entry participate in the lookup map.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback that generates the body of the outlined kernel function:
  // remaps block arguments, performs privatization, and inlines the region.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map each map/has_device_addr entry block argument to the LLVM value of
    // the variable it mirrors.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    PrivateVarsInfo privateVarsInfo(targetOp);

    // NOTE(review): the declaration of `afterAllocas` (receiving the result of
    // this call) appears to be missing from this excerpt — confirm against the
    // full file.
        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
                            allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                    &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect the `dealloc` regions of all privatizers so cleanup can be
    // inlined at region exit.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    // NOTE(review): the declaration of `exitBlock` (result of a
    // `convertOmpOpRegions(...)` call) appears to be missing from this
    // excerpt — confirm against the full file.
        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  // Lazily computes the combined map infos at the given insertion point.
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                targetDirective);
    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  // NOTE(review): the declaration of `kernelInput` appears to be missing from
  // this excerpt — confirm against the full file.
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likley need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  // NOTE(review): the declaration of `dds` (depend data vector) appears to be
  // missing from this excerpt — confirm against the full file.
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Returns the user-defined mapper function for map entry `i`, if any.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
6761
6762static LogicalResult
6763convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
6764 LLVM::ModuleTranslation &moduleTranslation) {
6765 // Amend omp.declare_target by deleting the IR of the outlined functions
6766 // created for target regions. They cannot be filtered out from MLIR earlier
6767 // because the omp.target operation inside must be translated to LLVM, but
6768 // the wrapper functions themselves must not remain at the end of the
6769 // process. We know that functions where omp.declare_target does not match
6770 // omp.is_target_device at this stage can only be wrapper functions because
6771 // those that aren't are removed earlier as an MLIR transformation pass.
6772 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6773 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6774 op->getParentOfType<ModuleOp>().getOperation())) {
6775 if (!offloadMod.getIsTargetDevice())
6776 return success();
6777
6778 omp::DeclareTargetDeviceType declareType =
6779 attribute.getDeviceType().getValue();
6780
6781 if (declareType == omp::DeclareTargetDeviceType::host) {
6782 llvm::Function *llvmFunc =
6783 moduleTranslation.lookupFunction(funcOp.getName());
6784 llvmFunc->dropAllReferences();
6785 llvmFunc->eraseFromParent();
6786 }
6787 }
6788 return success();
6789 }
6790
6791 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6792 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6793 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6794 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6795 bool isDeclaration = gOp.isDeclaration();
6796 bool isExternallyVisible =
6797 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
6798 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
6799 llvm::StringRef mangledName = gOp.getSymName();
6800 auto captureClause =
6801 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
6802 auto deviceClause =
6803 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
6804 // unused for MLIR at the moment, required in Clang for book
6805 // keeping
6806 std::vector<llvm::GlobalVariable *> generatedRefs;
6807
6808 std::vector<llvm::Triple> targetTriple;
6809 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6810 op->getParentOfType<mlir::ModuleOp>()->getAttr(
6811 LLVM::LLVMDialect::getTargetTripleAttrName()));
6812 if (targetTripleAttr)
6813 targetTriple.emplace_back(targetTripleAttr.data());
6814
6815 auto fileInfoCallBack = [&loc]() {
6816 std::string filename = "";
6817 std::uint64_t lineNo = 0;
6818
6819 if (loc) {
6820 filename = loc.getFilename().str();
6821 lineNo = loc.getLine();
6822 }
6823
6824 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6825 lineNo);
6826 };
6827
6828 auto vfs = llvm::vfs::getRealFileSystem();
6829
6830 ompBuilder->registerTargetGlobalVariable(
6831 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6832 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6833 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6834 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
6835 gVal->getType(), gVal);
6836
6837 if (ompBuilder->Config.isTargetDevice() &&
6838 (attribute.getCaptureClause().getValue() !=
6839 mlir::omp::DeclareTargetCaptureClause::to ||
6840 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6841 ompBuilder->getAddrOfDeclareTargetVar(
6842 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6843 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6844 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6845 gVal->getType(), /*GlobalInitializer*/ nullptr,
6846 /*VariableLinkage*/ nullptr);
6847 }
6848 }
6849 }
6850
6851 return success();
6852}
6853
6854namespace {
6855
6856/// Implementation of the dialect interface that converts operations belonging
6857/// to the OpenMP dialect to LLVM IR.
6858class OpenMPDialectLLVMIRTranslationInterface
6859 : public LLVMTranslationDialectInterface {
6860public:
6862
6863 /// Translates the given operation to LLVM IR using the provided IR builder
6864 /// and saving the state in `moduleTranslation`.
6865 LogicalResult
6866 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
6867 LLVM::ModuleTranslation &moduleTranslation) const final;
6868
6869 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
6870 /// runtime calls, or operation amendments
6871 LogicalResult
6872 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
6873 NamedAttribute attribute,
6874 LLVM::ModuleTranslation &moduleTranslation) const final;
6875};
6876
6877} // namespace
6878
6879LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6880 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6881 NamedAttribute attribute,
6882 LLVM::ModuleTranslation &moduleTranslation) const {
6883 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6884 attribute.getName())
6885 .Case("omp.is_target_device",
6886 [&](Attribute attr) {
6887 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6888 llvm::OpenMPIRBuilderConfig &config =
6889 moduleTranslation.getOpenMPBuilder()->Config;
6890 config.setIsTargetDevice(deviceAttr.getValue());
6891 return success();
6892 }
6893 return failure();
6894 })
6895 .Case("omp.is_gpu",
6896 [&](Attribute attr) {
6897 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6898 llvm::OpenMPIRBuilderConfig &config =
6899 moduleTranslation.getOpenMPBuilder()->Config;
6900 config.setIsGPU(gpuAttr.getValue());
6901 return success();
6902 }
6903 return failure();
6904 })
6905 .Case("omp.host_ir_filepath",
6906 [&](Attribute attr) {
6907 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6908 llvm::OpenMPIRBuilder *ompBuilder =
6909 moduleTranslation.getOpenMPBuilder();
6910 auto VFS = llvm::vfs::getRealFileSystem();
6911 ompBuilder->loadOffloadInfoMetadata(*VFS,
6912 filepathAttr.getValue());
6913 return success();
6914 }
6915 return failure();
6916 })
6917 .Case("omp.flags",
6918 [&](Attribute attr) {
6919 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6920 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6921 return failure();
6922 })
6923 .Case("omp.version",
6924 [&](Attribute attr) {
6925 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6926 llvm::OpenMPIRBuilder *ompBuilder =
6927 moduleTranslation.getOpenMPBuilder();
6928 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6929 versionAttr.getVersion());
6930 return success();
6931 }
6932 return failure();
6933 })
6934 .Case("omp.declare_target",
6935 [&](Attribute attr) {
6936 if (auto declareTargetAttr =
6937 dyn_cast<omp::DeclareTargetAttr>(attr))
6938 return convertDeclareTargetAttr(op, declareTargetAttr,
6939 moduleTranslation);
6940 return failure();
6941 })
6942 .Case("omp.requires",
6943 [&](Attribute attr) {
6944 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6945 using Requires = omp::ClauseRequires;
6946 Requires flags = requiresAttr.getValue();
6947 llvm::OpenMPIRBuilderConfig &config =
6948 moduleTranslation.getOpenMPBuilder()->Config;
6949 config.setHasRequiresReverseOffload(
6950 bitEnumContainsAll(flags, Requires::reverse_offload));
6951 config.setHasRequiresUnifiedAddress(
6952 bitEnumContainsAll(flags, Requires::unified_address));
6953 config.setHasRequiresUnifiedSharedMemory(
6954 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6955 config.setHasRequiresDynamicAllocators(
6956 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6957 return success();
6958 }
6959 return failure();
6960 })
6961 .Case("omp.target_triples",
6962 [&](Attribute attr) {
6963 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6964 llvm::OpenMPIRBuilderConfig &config =
6965 moduleTranslation.getOpenMPBuilder()->Config;
6966 config.TargetTriples.clear();
6967 config.TargetTriples.reserve(triplesAttr.size());
6968 for (Attribute tripleAttr : triplesAttr) {
6969 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6970 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6971 else
6972 return failure();
6973 }
6974 return success();
6975 }
6976 return failure();
6977 })
6978 .Default([](Attribute) {
6979 // Fall through for omp attributes that do not require lowering.
6980 return success();
6981 })(attribute.getValue());
6982
6983 return failure();
6984}
6985
6986// Returns true if the operation is not inside a TargetOp, it is part of a
6987// function and that function is not declare target.
6988static bool isHostDeviceOp(Operation *op) {
6989 // Assumes no reverse offloading
6990 if (op->getParentOfType<omp::TargetOp>())
6991 return false;
6992
6993 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
6994 if (auto declareTargetIface =
6995 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6996 parentFn.getOperation()))
6997 if (declareTargetIface.isDeclareTarget() &&
6998 declareTargetIface.getDeclareTargetDeviceType() !=
6999 mlir::omp::DeclareTargetDeviceType::host)
7000 return false;
7001
7002 return true;
7003 }
7004
7005 return false;
7006}
7007
7008static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7009 llvm::Module *llvmModule) {
7010 llvm::Type *i64Ty = builder.getInt64Ty();
7011 llvm::Type *i32Ty = builder.getInt32Ty();
7012 llvm::Type *returnType = builder.getPtrTy(0);
7013 llvm::FunctionType *fnType =
7014 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7015 llvm::Function *func = cast<llvm::Function>(
7016 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7017 return func;
7018}
7019
7020static LogicalResult
7021convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7022 LLVM::ModuleTranslation &moduleTranslation) {
7023 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7024 if (!allocMemOp)
7025 return failure();
7026
7027 // Get "omp_target_alloc" function
7028 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7029 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7030 // Get the corresponding device value in llvm
7031 mlir::Value deviceNum = allocMemOp.getDevice();
7032 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7033 // Get the allocation size.
7034 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7035 mlir::Type heapTy = allocMemOp.getAllocatedType();
7036 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
7037 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7038 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7039 for (auto typeParam : allocMemOp.getTypeparams())
7040 allocSize =
7041 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
7042 // Create call to "omp_target_alloc" with the args as translated llvm values.
7043 llvm::CallInst *call =
7044 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7045 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7046
7047 // Map the result
7048 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7049 return success();
7050}
7051
7052static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
7053 llvm::Module *llvmModule) {
7054 llvm::Type *ptrTy = builder.getPtrTy(0);
7055 llvm::Type *i32Ty = builder.getInt32Ty();
7056 llvm::Type *voidTy = builder.getVoidTy();
7057 llvm::FunctionType *fnType =
7058 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
7059 llvm::Function *func = dyn_cast<llvm::Function>(
7060 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
7061 return func;
7062}
7063
7064static LogicalResult
7065convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7066 LLVM::ModuleTranslation &moduleTranslation) {
7067 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7068 if (!freeMemOp)
7069 return failure();
7070
7071 // Get "omp_target_free" function
7072 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7073 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7074 // Get the corresponding device value in llvm
7075 mlir::Value deviceNum = freeMemOp.getDevice();
7076 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7077 // Get the corresponding heapref value in llvm
7078 mlir::Value heapref = freeMemOp.getHeapref();
7079 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7080 // Convert heapref int to ptr and call "omp_target_free"
7081 llvm::Value *intToPtr =
7082 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7083 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7084 return success();
7085}
7086
7087/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
7088/// OpenMP runtime calls).
7089LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7090 Operation *op, llvm::IRBuilderBase &builder,
7091 LLVM::ModuleTranslation &moduleTranslation) const {
7092 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7093
7094 if (ompBuilder->Config.isTargetDevice() &&
7095 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7096 op) &&
7097 isHostDeviceOp(op))
7098 return op->emitOpError() << "unsupported host op found in device";
7099
7100 // For each loop, introduce one stack frame to hold loop information. Ensure
7101 // this is only done for the outermost loop wrapper to prevent introducing
7102 // multiple stack frames for a single loop. Initially set to null, the loop
7103 // information structure is initialized during translation of the nested
7104 // omp.loop_nest operation, making it available to translation of all loop
7105 // wrappers after their body has been successfully translated.
7106 bool isOutermostLoopWrapper =
7107 isa_and_present<omp::LoopWrapperInterface>(op) &&
7108 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
7109
7110 if (isOutermostLoopWrapper)
7111 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7112
7113 auto result =
7114 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7115 .Case([&](omp::BarrierOp op) -> LogicalResult {
7117 return failure();
7118
7119 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7120 ompBuilder->createBarrier(builder.saveIP(),
7121 llvm::omp::OMPD_barrier);
7122 LogicalResult res = handleError(afterIP, *op);
7123 if (res.succeeded()) {
7124 // If the barrier generated a cancellation check, the insertion
7125 // point might now need to be changed to a new continuation block
7126 builder.restoreIP(*afterIP);
7127 }
7128 return res;
7129 })
7130 .Case([&](omp::TaskyieldOp op) {
7132 return failure();
7133
7134 ompBuilder->createTaskyield(builder.saveIP());
7135 return success();
7136 })
7137 .Case([&](omp::FlushOp op) {
7139 return failure();
7140
7141 // No support in Openmp runtime function (__kmpc_flush) to accept
7142 // the argument list.
7143 // OpenMP standard states the following:
7144 // "An implementation may implement a flush with a list by ignoring
7145 // the list, and treating it the same as a flush without a list."
7146 //
7147 // The argument list is discarded so that, flush with a list is
7148 // treated same as a flush without a list.
7149 ompBuilder->createFlush(builder.saveIP());
7150 return success();
7151 })
7152 .Case([&](omp::ParallelOp op) {
7153 return convertOmpParallel(op, builder, moduleTranslation);
7154 })
7155 .Case([&](omp::MaskedOp) {
7156 return convertOmpMasked(*op, builder, moduleTranslation);
7157 })
7158 .Case([&](omp::MasterOp) {
7159 return convertOmpMaster(*op, builder, moduleTranslation);
7160 })
7161 .Case([&](omp::CriticalOp) {
7162 return convertOmpCritical(*op, builder, moduleTranslation);
7163 })
7164 .Case([&](omp::OrderedRegionOp) {
7165 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
7166 })
7167 .Case([&](omp::OrderedOp) {
7168 return convertOmpOrdered(*op, builder, moduleTranslation);
7169 })
7170 .Case([&](omp::WsloopOp) {
7171 return convertOmpWsloop(*op, builder, moduleTranslation);
7172 })
7173 .Case([&](omp::SimdOp) {
7174 return convertOmpSimd(*op, builder, moduleTranslation);
7175 })
7176 .Case([&](omp::AtomicReadOp) {
7177 return convertOmpAtomicRead(*op, builder, moduleTranslation);
7178 })
7179 .Case([&](omp::AtomicWriteOp) {
7180 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
7181 })
7182 .Case([&](omp::AtomicUpdateOp op) {
7183 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
7184 })
7185 .Case([&](omp::AtomicCaptureOp op) {
7186 return convertOmpAtomicCapture(op, builder, moduleTranslation);
7187 })
7188 .Case([&](omp::CancelOp op) {
7189 return convertOmpCancel(op, builder, moduleTranslation);
7190 })
7191 .Case([&](omp::CancellationPointOp op) {
7192 return convertOmpCancellationPoint(op, builder, moduleTranslation);
7193 })
7194 .Case([&](omp::SectionsOp) {
7195 return convertOmpSections(*op, builder, moduleTranslation);
7196 })
7197 .Case([&](omp::SingleOp op) {
7198 return convertOmpSingle(op, builder, moduleTranslation);
7199 })
7200 .Case([&](omp::TeamsOp op) {
7201 return convertOmpTeams(op, builder, moduleTranslation);
7202 })
7203 .Case([&](omp::TaskOp op) {
7204 return convertOmpTaskOp(op, builder, moduleTranslation);
7205 })
7206 .Case([&](omp::TaskloopOp op) {
7207 return convertOmpTaskloopOp(*op, builder, moduleTranslation);
7208 })
7209 .Case([&](omp::TaskgroupOp op) {
7210 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
7211 })
7212 .Case([&](omp::TaskwaitOp op) {
7213 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
7214 })
7215 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7216 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7217 omp::CriticalDeclareOp>([](auto op) {
7218 // `yield` and `terminator` can be just omitted. The block structure
7219 // was created in the region that handles their parent operation.
7220 // `declare_reduction` will be used by reductions and is not
7221 // converted directly, skip it.
7222 // `declare_mapper` and `declare_mapper.info` are handled whenever
7223 // they are referred to through a `map` clause.
7224 // `critical.declare` is only used to declare names of critical
7225 // sections which will be used by `critical` ops and hence can be
7226 // ignored for lowering. The OpenMP IRBuilder will create unique
7227 // name for critical section names.
7228 return success();
7229 })
7230 .Case([&](omp::ThreadprivateOp) {
7231 return convertOmpThreadprivate(*op, builder, moduleTranslation);
7232 })
7233 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7234 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
7235 return convertOmpTargetData(op, builder, moduleTranslation);
7236 })
7237 .Case([&](omp::TargetOp) {
7238 return convertOmpTarget(*op, builder, moduleTranslation);
7239 })
7240 .Case([&](omp::DistributeOp) {
7241 return convertOmpDistribute(*op, builder, moduleTranslation);
7242 })
7243 .Case([&](omp::LoopNestOp) {
7244 return convertOmpLoopNest(*op, builder, moduleTranslation);
7245 })
7246 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7247 [&](auto op) {
7248 // No-op, should be handled by relevant owning operations e.g.
7249 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
7250 // etc. and then discarded
7251 return success();
7252 })
7253 .Case([&](omp::NewCliOp op) {
7254 // Meta-operation: Doesn't do anything by itself, but used to
7255 // identify a loop.
7256 return success();
7257 })
7258 .Case([&](omp::CanonicalLoopOp op) {
7259 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
7260 })
7261 .Case([&](omp::UnrollHeuristicOp op) {
7262 // FIXME: Handling omp.unroll_heuristic as an executable requires
7263 // that the generator (e.g. omp.canonical_loop) has been seen first.
7264 // For construct that require all codegen to occur inside a callback
7265 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
7266 // contained region including their transformations must occur at
7267 // the omp.canonical_loop.
7268 return applyUnrollHeuristic(op, builder, moduleTranslation);
7269 })
7270 .Case([&](omp::TileOp op) {
7271 return applyTile(op, builder, moduleTranslation);
7272 })
7273 .Case([&](omp::TargetAllocMemOp) {
7274 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
7275 })
7276 .Case([&](omp::TargetFreeMemOp) {
7277 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
7278 })
7279 .Default([&](Operation *inst) {
7280 return inst->emitError()
7281 << "not yet implemented: " << inst->getName();
7282 });
7283
7284 if (isOutermostLoopWrapper)
7285 moduleTranslation.stackPop();
7286
7287 return result;
7288}
7289
7291 registry.insert<omp::OpenMPDialect>();
7292 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7293 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7294 });
7295}
7296
7298 DialectRegistry registry;
7300 context.appendDialectRegistry(registry);
7301}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped values for a list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1294
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars