MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
  // NOTE(review): upstream declares a
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame) here
  // (cf. OpenMPLoopInfoStackFrame below); the line appears dropped in this
  // copy — confirm against the original source.

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  // Insertion point that nested translations should use for their allocas.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
81
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Filled in by the loop-nest translation; null until then.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
91
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
/// Its purpose is to serve as the glue between MLIR failures represented as
/// \see LogicalResult instances and \see llvm::Error instances used to
/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
/// error of the first type is raised, a message is emitted directly (the \see
/// LogicalResult itself does not hold any information). If we need to forward
/// this error condition as an \see llvm::Error while avoiding triggering some
/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
/// class to just signal this situation has happened.
///
/// For example, this class should be used to trigger errors from within
/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
/// translation of their own regions. This unclutters the error log from
/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  // Intentionally silent: the diagnostic was already emitted at the point of
  // failure, so logging here would duplicate it.
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

// Definition of the RTTI anchor used by llvm::ErrorInfo's classof machinery.
char PreviouslyReportedError::ID = 0;
125
/// Helper for translating the `linear` clause on `omp.wsloop` and `omp.simd`.
///
/// Linear-clause translation requires setup, initialization, update, and
/// finalization steps emitted into different basic blocks of the generated IR.
/// This class keeps the per-variable state (allocas, steps, types, split
/// blocks) so that each stage stays consistent with the others. All the
/// SmallVectors below are parallel arrays indexed by linear-variable position.

class LinearClauseProcessor {

private:
  // Alloca per linear variable holding its value on loop entry (the
  // "precondition" value the per-iteration update is computed from).
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Alloca per linear variable holding the per-iteration updated value.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // Original storage of each linear variable, as mapped from MLIR.
  SmallVector<llvm::Value *> linearOrigVal;
  // Step value of each linear variable.
  SmallVector<llvm::Value *> linearSteps;
  // LLVM type of each linear variable.
  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks produced by splitLinearFiniBB; used to run the final stores only on
  // the last logical iteration.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register the LLVM type for one linear variable. `ty` must be a TypeAttr.
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space (precondition + result allocas) for the linear variable at
  // position `idx`; `registerType` must already have run for that index.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Record the (already translated) step value for one linear variable.
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR (in the loop preheader) copying each variable's initial value into
  // its precondition alloca.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR (in the loop body) computing, for every linear variable,
  // start + iv * step into its per-iteration result alloca.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "
                         "type");

      if (linearVarType->isIntegerTy()) {
        // Integer path: normalize all arithmetic to linearVarType
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        // Float path: perform multiply in integer, then convert to float
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else {
        llvm_unreachable(
            "Linear variable must be of integer or floating-point type");
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same: loopExit is split into
  // finalization -> last-iter-exit -> exit blocks (fall-through order).
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: write the per-iteration values back to the
  // original variables on the last logical iteration only (guarded by
  // `lastIter`, an i32 flag), then emit a barrier in the exit block.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Emit unconditional stores of the per-iteration values back to the original
  // variables. Useful in case of the SIMD construct, which has no
  // last-iteration guard.
  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  // Rewrite all uses of the original variable appearing in basic blocks whose
  // name contains `BBName` to use the per-iteration temp instead, in-place.
  // Users are collected first because replaceUsesOfWith mutates the use list.
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
297
298} // namespace
299
300/// Looks up from the operation from and returns the PrivateClauseOp with
301/// name symbolName
302static omp::PrivateClauseOp findPrivatizer(Operation *from,
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
306 symbolName);
307 assert(privatizer && "privatizer not found in the symbol table");
308 return privatizer;
309}
310
311/// Check whether translation to LLVM IR for the given operation is currently
312/// supported. If not, descriptive diagnostics will be emitted to let users know
313/// this is a not-yet-implemented feature.
314///
315/// \returns success if no unimplemented features are needed to translate the
316/// given operation.
317static LogicalResult checkImplementationStatus(Operation &op) {
318 auto todo = [&op](StringRef clauseName) {
319 return op.emitError() << "not yet implemented: Unhandled clause "
320 << clauseName << " in " << op.getName()
321 << " operation";
322 };
323
324 auto checkAffinity = [&todo](auto op, LogicalResult &result) {
325 if (!op.getAffinityVars().empty())
326 result = todo("affinity");
327 };
328 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
329 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
330 result = todo("allocate");
331 };
332 auto checkBare = [&todo](auto op, LogicalResult &result) {
333 if (op.getBare())
334 result = todo("ompx_bare");
335 };
336 auto checkDepend = [&todo](auto op, LogicalResult &result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
338 result = todo("depend");
339 };
340 auto checkHint = [](auto op, LogicalResult &) {
341 if (op.getHint())
342 op.emitWarning("hint clause discarded");
343 };
344 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo("in_reduction");
348 };
349 auto checkNowait = [&todo](auto op, LogicalResult &result) {
350 if (op.getNowait())
351 result = todo("nowait");
352 };
353 auto checkOrder = [&todo](auto op, LogicalResult &result) {
354 if (op.getOrder() || op.getOrderMod())
355 result = todo("order");
356 };
357 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
358 if (op.getParLevelSimd())
359 result = todo("parallelization-level");
360 };
361 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo("privatization");
364 };
365 auto checkReduction = [&todo](auto op, LogicalResult &result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo("reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo("reduction with modifier");
373 };
374 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo("task_reduction");
378 };
379 auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
380 if (op.hasNumTeamsMultiDim())
381 result = todo("num_teams with multi-dimensional values");
382 };
383 auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
384 if (op.hasNumThreadsMultiDim())
385 result = todo("num_threads with multi-dimensional values");
386 };
387
388 auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
389 if (op.hasThreadLimitMultiDim())
390 result = todo("thread_limit with multi-dimensional values");
391 };
392
393 LogicalResult result = success();
395 .Case([&](omp::DistributeOp op) {
396 checkAllocate(op, result);
397 checkOrder(op, result);
398 })
399 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
400 .Case([&](omp::SectionsOp op) {
401 checkAllocate(op, result);
402 checkPrivate(op, result);
403 checkReduction(op, result);
404 })
405 .Case([&](omp::SingleOp op) {
406 checkAllocate(op, result);
407 checkPrivate(op, result);
408 })
409 .Case([&](omp::TeamsOp op) {
410 checkAllocate(op, result);
411 checkPrivate(op, result);
412 checkNumTeams(op, result);
413 checkThreadLimit(op, result);
414 })
415 .Case([&](omp::TaskOp op) {
416 checkAffinity(op, result);
417 checkAllocate(op, result);
418 checkInReduction(op, result);
419 })
420 .Case([&](omp::TaskgroupOp op) {
421 checkAllocate(op, result);
422 checkTaskReduction(op, result);
423 })
424 .Case([&](omp::TaskwaitOp op) {
425 checkDepend(op, result);
426 checkNowait(op, result);
427 })
428 .Case([&](omp::TaskloopOp op) {
429 checkAllocate(op, result);
430 checkInReduction(op, result);
431 checkReduction(op, result);
432 })
433 .Case([&](omp::WsloopOp op) {
434 checkAllocate(op, result);
435 checkOrder(op, result);
436 checkReduction(op, result);
437 })
438 .Case([&](omp::ParallelOp op) {
439 checkAllocate(op, result);
440 checkReduction(op, result);
441 checkNumThreads(op, result);
442 })
443 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
444 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
445 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
446 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
447 [&](auto op) { checkDepend(op, result); })
448 .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
449 .Case([&](omp::TargetOp op) {
450 checkAllocate(op, result);
451 checkBare(op, result);
452 checkInReduction(op, result);
453 checkThreadLimit(op, result);
454 })
455 .Default([](Operation &) {
456 // Assume all clauses for an operation can be translated unless they are
457 // checked above.
458 });
459 return result;
460}
461
462static LogicalResult handleError(llvm::Error error, Operation &op) {
463 LogicalResult result = success();
464 if (error) {
465 llvm::handleAllErrors(
466 std::move(error),
467 [&](const PreviouslyReportedError &) { result = failure(); },
468 [&](const llvm::ErrorInfoBase &err) {
469 result = op.emitError(err.message());
470 });
471 }
472 return result;
473}
474
475template <typename T>
476static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
477 if (!result)
478 return handleError(result.takeError(), op);
479
480 return success();
481}
482
/// Find the insertion point for allocas given the current insertion point for
/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it. The walk is interrupted at the first (innermost) frame found.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  // In cases with multiple levels of outlining, the tree walk might find an
  // alloca insertion point that is inside the original function while the
  // builder insertion point is inside the outlined function. We need to make
  // sure that we do not use it in those cases.
  if (walkResult.wasInterrupted() &&
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocas.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    // Chain a fresh block after the entry block and move the builder there,
    // leaving the entry block free for allocas.
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
529
530/// Find the loop information structure for the loop nest being translated. It
531/// will return a `null` value unless called from the translation function for
532/// a loop wrapper operation after successfully translating its body.
533static llvm::CanonicalLoopInfo *
535 llvm::CanonicalLoopInfo *loopInfo = nullptr;
536 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
537 [&](OpenMPLoopInfoStackFrame &frame) {
538 loopInfo = frame.loopInfo;
539 return WalkResult::interrupt();
540 });
541 return loopInfo;
542}
543
544/// Converts the given region that appears within an OpenMP dialect operation to
545/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
546/// region, and a branch from any block with an successor-less OpenMP terminator
547/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
548/// of the continuation block if provided.
550 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
551 LLVM::ModuleTranslation &moduleTranslation,
552 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
553 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
554
555 llvm::BasicBlock *continuationBlock =
556 splitBB(builder, true, "omp.region.cont");
557 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
558
559 llvm::LLVMContext &llvmContext = builder.getContext();
560 for (Block &bb : region) {
561 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
562 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
563 builder.GetInsertBlock()->getNextNode());
564 moduleTranslation.mapBlock(&bb, llvmBB);
565 }
566
567 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
568
569 // Terminators (namely YieldOp) may be forwarding values to the region that
570 // need to be available in the continuation block. Collect the types of these
571 // operands in preparation of creating PHI nodes. This is skipped for loop
572 // wrapper operations, for which we know in advance they have no terminators.
573 SmallVector<llvm::Type *> continuationBlockPHITypes;
574 unsigned numYields = 0;
575
576 if (!isLoopWrapper) {
577 bool operandsProcessed = false;
578 for (Block &bb : region.getBlocks()) {
579 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
580 if (!operandsProcessed) {
581 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
582 continuationBlockPHITypes.push_back(
583 moduleTranslation.convertType(yield->getOperand(i).getType()));
584 }
585 operandsProcessed = true;
586 } else {
587 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
588 "mismatching number of values yielded from the region");
589 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
590 llvm::Type *operandType =
591 moduleTranslation.convertType(yield->getOperand(i).getType());
592 (void)operandType;
593 assert(continuationBlockPHITypes[i] == operandType &&
594 "values of mismatching types yielded from the region");
595 }
596 }
597 numYields++;
598 }
599 }
600 }
601
602 // Insert PHI nodes in the continuation block for any values forwarded by the
603 // terminators in this region.
604 if (!continuationBlockPHITypes.empty())
605 assert(
606 continuationBlockPHIs &&
607 "expected continuation block PHIs if converted regions yield values");
608 if (continuationBlockPHIs) {
609 llvm::IRBuilderBase::InsertPointGuard guard(builder);
610 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
611 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
612 for (llvm::Type *ty : continuationBlockPHITypes)
613 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
614 }
615
616 // Convert blocks one by one in topological order to ensure
617 // defs are converted before uses.
619 for (Block *bb : blocks) {
620 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
621 // Retarget the branch of the entry block to the entry block of the
622 // converted region (regions are single-entry).
623 if (bb->isEntryBlock()) {
624 assert(sourceTerminator->getNumSuccessors() == 1 &&
625 "provided entry block has multiple successors");
626 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
627 "ContinuationBlock is not the successor of the entry block");
628 sourceTerminator->setSuccessor(0, llvmBB);
629 }
630
631 llvm::IRBuilderBase::InsertPointGuard guard(builder);
632 if (failed(
633 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
634 return llvm::make_error<PreviouslyReportedError>();
635
636 // Create a direct branch here for loop wrappers to prevent their lack of a
637 // terminator from causing a crash below.
638 if (isLoopWrapper) {
639 builder.CreateBr(continuationBlock);
640 continue;
641 }
642
643 // Special handling for `omp.yield` and `omp.terminator` (we may have more
644 // than one): they return the control to the parent OpenMP dialect operation
645 // so replace them with the branch to the continuation block. We handle this
646 // here to avoid relying inter-function communication through the
647 // ModuleTranslation class to set up the correct insertion point. This is
648 // also consistent with MLIR's idiom of handling special region terminators
649 // in the same code that handles the region-owning operation.
650 Operation *terminator = bb->getTerminator();
651 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
652 builder.CreateBr(continuationBlock);
653
654 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
655 (*continuationBlockPHIs)[i]->addIncoming(
656 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
657 }
658 }
659 // After all blocks have been traversed and values mapped, connect the PHI
660 // nodes to the results of preceding blocks.
661 LLVM::detail::connectPHINodes(region, moduleTranslation);
662
663 // Remove the blocks and values defined in this region from the mapping since
664 // they are not visible outside of this region. This allows the same region to
665 // be converted several times, that is cloned, without clashes, and slightly
666 // speeds up the lookups.
667 moduleTranslation.forgetMapping(region);
668
669 return continuationBlock;
670}
671
672/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
673static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
674 switch (kind) {
675 case omp::ClauseProcBindKind::Close:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
677 case omp::ClauseProcBindKind::Master:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
679 case omp::ClauseProcBindKind::Primary:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
681 case omp::ClauseProcBindKind::Spread:
682 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
683 }
684 llvm_unreachable("Unknown ClauseProcBindKind kind");
685}
686
687/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
688static LogicalResult
689convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
690 LLVM::ModuleTranslation &moduleTranslation) {
691 auto maskedOp = cast<omp::MaskedOp>(opInst);
692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
693
694 if (failed(checkImplementationStatus(opInst)))
695 return failure();
696
697 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
698 // MaskedOp has only one region associated with it.
699 auto &region = maskedOp.getRegion();
700 builder.restoreIP(codeGenIP);
701 return convertOmpOpRegions(region, "omp.masked.region", builder,
702 moduleTranslation)
703 .takeError();
704 };
705
706 // TODO: Perform finalization actions for variables. This has to be
707 // called for variables which have destructors/finalizers.
708 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
709
710 llvm::Value *filterVal = nullptr;
711 if (auto filterVar = maskedOp.getFilteredThreadId()) {
712 filterVal = moduleTranslation.lookupValue(filterVar);
713 } else {
714 llvm::LLVMContext &llvmContext = builder.getContext();
715 filterVal =
716 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
717 }
718 assert(filterVal != nullptr);
719 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
720 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
721 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
722 finiCB, filterVal);
723
724 if (failed(handleError(afterIP, opInst)))
725 return failure();
726
727 builder.restoreIP(*afterIP);
728 return success();
729}
730
731/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
732static LogicalResult
733convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
734 LLVM::ModuleTranslation &moduleTranslation) {
735 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
736 auto masterOp = cast<omp::MasterOp>(opInst);
737
738 if (failed(checkImplementationStatus(opInst)))
739 return failure();
740
741 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
742 // MasterOp has only one region associated with it.
743 auto &region = masterOp.getRegion();
744 builder.restoreIP(codeGenIP);
745 return convertOmpOpRegions(region, "omp.master.region", builder,
746 moduleTranslation)
747 .takeError();
748 };
749
750 // TODO: Perform finalization actions for variables. This has to be
751 // called for variables which have destructors/finalizers.
752 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
753
754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
756 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
757 finiCB);
758
759 if (failed(handleError(afterIP, opInst)))
760 return failure();
761
762 builder.restoreIP(*afterIP);
763 return success();
764}
765
766/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
767static LogicalResult
768convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
769 LLVM::ModuleTranslation &moduleTranslation) {
770 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
771 auto criticalOp = cast<omp::CriticalOp>(opInst);
772
773 if (failed(checkImplementationStatus(opInst)))
774 return failure();
775
776 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
777 // CriticalOp has only one region associated with it.
778 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
779 builder.restoreIP(codeGenIP);
780 return convertOmpOpRegions(region, "omp.critical.region", builder,
781 moduleTranslation)
782 .takeError();
783 };
784
785 // TODO: Perform finalization actions for variables. This has to be
786 // called for variables which have destructors/finalizers.
787 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
788
789 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
790 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
791 llvm::Constant *hint = nullptr;
792
793 // If it has a name, it probably has a hint too.
794 if (criticalOp.getNameAttr()) {
795 // The verifiers in OpenMP Dialect guarentee that all the pointers are
796 // non-null
797 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
798 auto criticalDeclareOp =
800 symbolRef);
801 hint =
802 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
803 static_cast<int>(criticalDeclareOp.getHint()));
804 }
805 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
806 moduleTranslation.getOpenMPBuilder()->createCritical(
807 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
808
809 if (failed(handleError(afterIP, opInst)))
810 return failure();
811
812 builder.restoreIP(*afterIP);
813 return success();
814}
815
816/// A util to collect info needed to convert delayed privatizers from MLIR to
817/// LLVM.
819 template <typename OP>
821 : blockArgs(
822 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
823 mlirVars.reserve(blockArgs.size());
824 llvmVars.reserve(blockArgs.size());
825 collectPrivatizationDecls<OP>(op);
826
827 for (mlir::Value privateVar : op.getPrivateVars())
828 mlirVars.push_back(privateVar);
829 }
830
835
836private:
837 /// Populates `privatizations` with privatization declarations used for the
838 /// given op.
839 template <class OP>
840 void collectPrivatizationDecls(OP op) {
841 std::optional<ArrayAttr> attr = op.getPrivateSyms();
842 if (!attr)
843 return;
844
845 privatizers.reserve(privatizers.size() + attr->size());
846 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
847 privatizers.push_back(findPrivatizer(op, symbolRef));
848 }
849 }
850};
851
852/// Populates `reductions` with reduction declarations used in the given op.
853template <typename T>
854static void
857 std::optional<ArrayAttr> attr = op.getReductionSyms();
858 if (!attr)
859 return;
860
861 reductions.reserve(reductions.size() + op.getNumReductionVars());
862 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
863 reductions.push_back(
865 op, symbolRef));
866 }
867}
868
869/// Translates the blocks contained in the given region and appends them to at
870/// the current insertion point of `builder`. The operations of the entry block
871/// are appended to the current insertion block. If set, `continuationBlockArgs`
872/// is populated with translated values that correspond to the values
873/// omp.yield'ed from the region.
874static LogicalResult inlineConvertOmpRegions(
875 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
876 LLVM::ModuleTranslation &moduleTranslation,
877 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
878 if (region.empty())
879 return success();
880
881 // Special case for single-block regions that don't create additional blocks:
882 // insert operations without creating additional blocks.
883 if (region.hasOneBlock()) {
884 llvm::Instruction *potentialTerminator =
885 builder.GetInsertBlock()->empty() ? nullptr
886 : &builder.GetInsertBlock()->back();
887
888 if (potentialTerminator && potentialTerminator->isTerminator())
889 potentialTerminator->removeFromParent();
890 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
891
892 if (failed(moduleTranslation.convertBlock(
893 region.front(), /*ignoreArguments=*/true, builder)))
894 return failure();
895
896 // The continuation arguments are simply the translated terminator operands.
897 if (continuationBlockArgs)
898 llvm::append_range(
899 *continuationBlockArgs,
900 moduleTranslation.lookupValues(region.front().back().getOperands()));
901
902 // Drop the mapping that is no longer necessary so that the same region can
903 // be processed multiple times.
904 moduleTranslation.forgetMapping(region);
905
906 if (potentialTerminator && potentialTerminator->isTerminator()) {
907 llvm::BasicBlock *block = builder.GetInsertBlock();
908 if (block->empty()) {
909 // this can happen for really simple reduction init regions e.g.
910 // %0 = llvm.mlir.constant(0 : i32) : i32
911 // omp.yield(%0 : i32)
912 // because the llvm.mlir.constant (MLIR op) isn't converted into any
913 // llvm op
914 potentialTerminator->insertInto(block, block->begin());
915 } else {
916 potentialTerminator->insertAfter(&block->back());
917 }
918 }
919
920 return success();
921 }
922
924 llvm::Expected<llvm::BasicBlock *> continuationBlock =
925 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
926
927 if (failed(handleError(continuationBlock, *region.getParentOp())))
928 return failure();
929
930 if (continuationBlockArgs)
931 llvm::append_range(*continuationBlockArgs, phis);
932 builder.SetInsertPoint(*continuationBlock,
933 (*continuationBlock)->getFirstInsertionPt());
934 return success();
935}
936
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: receives the LHS/RHS values and yields the combined
/// value through the `llvm::Value *&` out-parameter.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// Atomic combiner: receives the element type and the LHS/RHS values; nothing
/// is yielded back to the caller.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// Generator for the `data_ptr_ptr` region of by-ref reductions: receives the
/// by-ref value and yields a pointer through the out-parameter.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
952
953/// Create an OpenMPIRBuilder-compatible reduction generator for the given
954/// reduction declaration. The generator uses `builder` but ignores its
955/// insertion point.
956static OwningReductionGen
957makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
958 LLVM::ModuleTranslation &moduleTranslation) {
959 // The lambda is mutable because we need access to non-const methods of decl
960 // (which aren't actually mutating it), and we must capture decl by-value to
961 // avoid the dangling reference after the parent function returns.
962 OwningReductionGen gen =
963 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
964 llvm::Value *lhs, llvm::Value *rhs,
965 llvm::Value *&result) mutable
966 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
967 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
968 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
969 builder.restoreIP(insertPoint);
971 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
972 "omp.reduction.nonatomic.body", builder,
973 moduleTranslation, &phis)))
974 return llvm::createStringError(
975 "failed to inline `combiner` region of `omp.declare_reduction`");
976 result = llvm::getSingleElement(phis);
977 return builder.saveIP();
978 };
979 return gen;
980}
981
982/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
983/// given reduction declaration. The generator uses `builder` but ignores its
984/// insertion point. Returns null if there is no atomic region available in the
985/// reduction declaration.
986static OwningAtomicReductionGen
987makeAtomicReductionGen(omp::DeclareReductionOp decl,
988 llvm::IRBuilderBase &builder,
989 LLVM::ModuleTranslation &moduleTranslation) {
990 if (decl.getAtomicReductionRegion().empty())
991 return OwningAtomicReductionGen();
992
993 // The lambda is mutable because we need access to non-const methods of decl
994 // (which aren't actually mutating it), and we must capture decl by-value to
995 // avoid the dangling reference after the parent function returns.
996 OwningAtomicReductionGen atomicGen =
997 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
998 llvm::Value *lhs, llvm::Value *rhs) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
1004 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1005 "omp.reduction.atomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `atomic` region of `omp.declare_reduction`");
1009 assert(phis.empty());
1010 return builder.saveIP();
1011 };
1012 return atomicGen;
1013}
1014
1015/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1016/// the given reduction declaration. The generator uses `builder` but ignores
1017/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1018/// available in the reduction declaration.
1019static OwningDataPtrPtrReductionGen
1020makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1021 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1022 if (!isByRef)
1023 return OwningDataPtrPtrReductionGen();
1024
1025 OwningDataPtrPtrReductionGen refDataPtrGen =
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1027 llvm::Value *byRefVal, llvm::Value *&result) mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1030 builder.restoreIP(insertPoint);
1032 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1033 "omp.data_ptr_ptr.body", builder,
1034 moduleTranslation, &phis)))
1035 return llvm::createStringError(
1036 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1037 result = llvm::getSingleElement(phis);
1038 return builder.saveIP();
1039 };
1040
1041 return refDataPtrGen;
1042}
1043
1044/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1045static LogicalResult
1046convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1047 LLVM::ModuleTranslation &moduleTranslation) {
1048 auto orderedOp = cast<omp::OrderedOp>(opInst);
1049
1050 if (failed(checkImplementationStatus(opInst)))
1051 return failure();
1052
1053 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1054 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1055 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1056 SmallVector<llvm::Value *> vecValues =
1057 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1058
1059 size_t indexVecValues = 0;
1060 while (indexVecValues < vecValues.size()) {
1061 SmallVector<llvm::Value *> storeValues;
1062 storeValues.reserve(numLoops);
1063 for (unsigned i = 0; i < numLoops; i++) {
1064 storeValues.push_back(vecValues[indexVecValues]);
1065 indexVecValues++;
1066 }
1067 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1068 findAllocaInsertPoint(builder, moduleTranslation);
1069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1070 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1071 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1072 }
1073 return success();
1074}
1075
1076/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1077/// OpenMPIRBuilder.
1078static LogicalResult
1079convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1080 LLVM::ModuleTranslation &moduleTranslation) {
1081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1082 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1083
1084 if (failed(checkImplementationStatus(opInst)))
1085 return failure();
1086
1087 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1088 // OrderedOp has only one region associated with it.
1089 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1090 builder.restoreIP(codeGenIP);
1091 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1092 moduleTranslation)
1093 .takeError();
1094 };
1095
1096 // TODO: Perform finalization actions for variables. This has to be
1097 // called for variables which have destructors/finalizers.
1098 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1099
1100 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1101 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1102 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1103 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1104
1105 if (failed(handleError(afterIP, opInst)))
1106 return failure();
1107
1108 builder.restoreIP(*afterIP);
1109 return success();
1110}
1111
1112namespace {
1113/// Contains the arguments for an LLVM store operation
1114struct DeferredStore {
1115 DeferredStore(llvm::Value *value, llvm::Value *address)
1116 : value(value), address(address) {}
1117
1118 llvm::Value *value;
1119 llvm::Value *address;
1120};
1121} // namespace
1122
1123/// Allocate space for privatized reduction variables.
1124/// `deferredStores` contains information to create store operations which needs
1125/// to be inserted after all allocas
1126template <typename T>
1127static LogicalResult
1129 llvm::IRBuilderBase &builder,
1130 LLVM::ModuleTranslation &moduleTranslation,
1131 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1133 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1134 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1135 SmallVectorImpl<DeferredStore> &deferredStores,
1136 llvm::ArrayRef<bool> isByRefs) {
1137 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1139
1140 // delay creating stores until after all allocas
1141 deferredStores.reserve(loop.getNumReductionVars());
1142
1143 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1144 Region &allocRegion = reductionDecls[i].getAllocRegion();
1145 if (isByRefs[i]) {
1146 if (allocRegion.empty())
1147 continue;
1148
1150 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1151 builder, moduleTranslation, &phis)))
1152 return loop.emitError(
1153 "failed to inline `alloc` region of `omp.declare_reduction`");
1154
1155 assert(phis.size() == 1 && "expected one allocation to be yielded");
1156 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1157
1158 // Allocate reduction variable (which is a pointer to the real reduction
1159 // variable allocated in the inlined region)
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.convertType(reductionDecls[i].getType()));
1162
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 llvm::Value *castPhi =
1167 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1168
1169 deferredStores.emplace_back(castPhi, castVar);
1170
1171 privateReductionVariables[i] = castVar;
1172 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1173 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1174 } else {
1175 assert(allocRegion.empty() &&
1176 "allocaction is implicit for by-val reduction");
1177 llvm::Value *var = builder.CreateAlloca(
1178 moduleTranslation.convertType(reductionDecls[i].getType()));
1179
1180 llvm::Type *ptrTy = builder.getPtrTy();
1181 llvm::Value *castVar =
1182 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1183
1184 moduleTranslation.mapValue(reductionArgs[i], castVar);
1185 privateReductionVariables[i] = castVar;
1186 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1187 }
1188 }
1189
1190 return success();
1191}
1192
1193/// Map input arguments to reduction initialization region
1194template <typename T>
1195static void
1197 llvm::IRBuilderBase &builder,
1199 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1200 unsigned i) {
1201 // map input argument to the initialization region
1202 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1203 Region &initializerRegion = reduction.getInitializerRegion();
1204 Block &entry = initializerRegion.front();
1205
1206 mlir::Value mlirSource = loop.getReductionVars()[i];
1207 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1208 llvm::Value *origVal = llvmSource;
1209 // If a non-pointer value is expected, load the value from the source pointer.
1210 if (!isa<LLVM::LLVMPointerType>(
1211 reduction.getInitializerMoldArg().getType()) &&
1212 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1213 origVal =
1214 builder.CreateLoad(moduleTranslation.convertType(
1215 reduction.getInitializerMoldArg().getType()),
1216 llvmSource, "omp_orig");
1217 }
1218 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1219
1220 if (entry.getNumArguments() > 1) {
1221 llvm::Value *allocation =
1222 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1223 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1224 }
1225}
1226
1227static void
1228setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1229 llvm::BasicBlock *block = nullptr) {
1230 if (block == nullptr)
1231 block = builder.GetInsertBlock();
1232
1233 if (block->empty() || block->getTerminator() == nullptr)
1234 builder.SetInsertPoint(block);
1235 else
1236 builder.SetInsertPoint(block->getTerminator());
1237}
1238
1239/// Inline reductions' `init` regions. This functions assumes that the
1240/// `builder`'s insertion point is where the user wants the `init` regions to be
1241/// inlined; i.e. it does not try to find a proper insertion location for the
1242/// `init` regions. It also leaves the `builder's insertions point in a state
1243/// where the user can continue the code-gen directly afterwards.
1244template <typename OP>
1245static LogicalResult
1246initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1247 llvm::IRBuilderBase &builder,
1248 LLVM::ModuleTranslation &moduleTranslation,
1249 llvm::BasicBlock *latestAllocaBlock,
1251 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1252 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1253 llvm::ArrayRef<bool> isByRef,
1254 SmallVectorImpl<DeferredStore> &deferredStores) {
1255 if (op.getNumReductionVars() == 0)
1256 return success();
1257
1258 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1259 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1260 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1261 builder.restoreIP(allocaIP);
1262 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1263
1264 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1265 if (isByRef[i]) {
1266 if (!reductionDecls[i].getAllocRegion().empty())
1267 continue;
1268
1269 // TODO: remove after all users of by-ref are updated to use the alloc
1270 // region: Allocate reduction variable (which is a pointer to the real
1271 // reduciton variable allocated in the inlined region)
1272 byRefVars[i] = builder.CreateAlloca(
1273 moduleTranslation.convertType(reductionDecls[i].getType()));
1274 }
1275 }
1276
1277 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1278
1279 // store result of the alloc region to the allocated pointer to the real
1280 // reduction variable
1281 for (auto [data, addr] : deferredStores)
1282 builder.CreateStore(data, addr);
1283
1284 // Before the loop, store the initial values of reductions into reduction
1285 // variables. Although this could be done after allocas, we don't want to mess
1286 // up with the alloca insertion point.
1287 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1289
1290 // map block argument to initializer region
1291 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1292 reductionVariableMap, i);
1293
1294 // TODO In some cases (specially on the GPU), the init regions may
1295 // contains stack alloctaions. If the region is inlined in a loop, this is
1296 // problematic. Instead of just inlining the region, handle allocations by
1297 // hoisting fixed length allocations to the function entry and using
1298 // stacksave and restore for variable length ones.
1299 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1300 "omp.reduction.neutral", builder,
1301 moduleTranslation, &phis)))
1302 return failure();
1303
1304 assert(phis.size() == 1 && "expected one value to be yielded from the "
1305 "reduction neutral element declaration region");
1306
1308
1309 if (isByRef[i]) {
1310 if (!reductionDecls[i].getAllocRegion().empty())
1311 // done in allocReductionVars
1312 continue;
1313
1314 // TODO: this path can be removed once all users of by-ref are updated to
1315 // use an alloc region
1316
1317 // Store the result of the inlined region to the allocated reduction var
1318 // ptr
1319 builder.CreateStore(phis[0], byRefVars[i]);
1320
1321 privateReductionVariables[i] = byRefVars[i];
1322 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1323 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1324 } else {
1325 // for by-ref case the store is inside of the reduction region
1326 builder.CreateStore(phis[0], privateReductionVariables[i]);
1327 // the rest was handled in allocByValReductionVars
1328 }
1329
1330 // forget the mapping for the initializer region because we might need a
1331 // different mapping if this reduction declaration is re-used for a
1332 // different variable
1333 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1334 }
1335
1336 return success();
1337}
1338
1339/// Collect reduction info
1340template <typename T>
1341static void collectReductionInfo(
1342 T loop, llvm::IRBuilderBase &builder,
1343 LLVM::ModuleTranslation &moduleTranslation,
1346 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1348 const ArrayRef<llvm::Value *> privateReductionVariables,
1350 ArrayRef<bool> isByRef) {
1351 unsigned numReductions = loop.getNumReductionVars();
1352
1353 for (unsigned i = 0; i < numReductions; ++i) {
1354 owningReductionGens.push_back(
1355 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1356 owningAtomicReductionGens.push_back(
1357 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
1359 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1360 }
1361
1362 // Collect the reduction information.
1363 reductionInfos.reserve(numReductions);
1364 for (unsigned i = 0; i < numReductions; ++i) {
1365 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1366 if (owningAtomicReductionGens[i])
1367 atomicGen = owningAtomicReductionGens[i];
1368 llvm::Value *variable =
1369 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1370 mlir::Type allocatedType;
1371 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1372 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1373 allocatedType = alloca.getElemType();
1375 }
1376
1378 });
1379
1380 reductionInfos.push_back(
1381 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1382 privateReductionVariables[i],
1383 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1385 /*ReductionGenClang=*/nullptr, atomicGen,
1387 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1388 reductionDecls[i].getByrefElementType()
1389 ? moduleTranslation.convertType(
1390 *reductionDecls[i].getByrefElementType())
1391 : nullptr});
1392 }
1393}
1394
1395/// handling of DeclareReductionOp's cleanup region
1396static LogicalResult
1398 llvm::ArrayRef<llvm::Value *> privateVariables,
1399 LLVM::ModuleTranslation &moduleTranslation,
1400 llvm::IRBuilderBase &builder, StringRef regionName,
1401 bool shouldLoadCleanupRegionArg = true) {
1402 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1403 if (cleanupRegion->empty())
1404 continue;
1405
1406 // map the argument to the cleanup region
1407 Block &entry = cleanupRegion->front();
1408
1409 llvm::Instruction *potentialTerminator =
1410 builder.GetInsertBlock()->empty() ? nullptr
1411 : &builder.GetInsertBlock()->back();
1412 if (potentialTerminator && potentialTerminator->isTerminator())
1413 builder.SetInsertPoint(potentialTerminator);
1414 llvm::Value *privateVarValue =
1415 shouldLoadCleanupRegionArg
1416 ? builder.CreateLoad(
1417 moduleTranslation.convertType(entry.getArgument(0).getType()),
1418 privateVariables[i])
1419 : privateVariables[i];
1420
1421 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1422
1423 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1424 moduleTranslation)))
1425 return failure();
1426
1427 // clear block argument mapping in case it needs to be re-created with a
1428 // different source for another use of the same reduction decl
1429 moduleTranslation.forgetMapping(*cleanupRegion);
1430 }
1431 return success();
1432}
1433
1434// TODO: not used by ParallelOp
1435template <class OP>
1436static LogicalResult createReductionsAndCleanup(
1437 OP op, llvm::IRBuilderBase &builder,
1438 LLVM::ModuleTranslation &moduleTranslation,
1439 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1441 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1442 bool isNowait = false, bool isTeamsReduction = false) {
1443 // Process the reductions if required.
1444 if (op.getNumReductionVars() == 0)
1445 return success();
1446
1448 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1449 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1451
1452 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1453
1454 // Create the reduction generators. We need to own them here because
1455 // ReductionInfo only accepts references to the generators.
1456 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1457 owningReductionGens, owningAtomicReductionGens,
1458 owningReductionGenRefDataPtrGens,
1459 privateReductionVariables, reductionInfos, isByRef);
1460
1461 // The call to createReductions below expects the block to have a
1462 // terminator. Create an unreachable instruction to serve as terminator
1463 // and remove it later.
1464 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1465 builder.SetInsertPoint(tempTerminator);
1466 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1467 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1468 isByRef, isNowait, isTeamsReduction);
1469
1470 if (failed(handleError(contInsertPoint, *op)))
1471 return failure();
1472
1473 if (!contInsertPoint->getBlock())
1474 return op->emitOpError() << "failed to convert reductions";
1475
1476 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1477 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1478
1479 if (failed(handleError(afterIP, *op)))
1480 return failure();
1481
1482 tempTerminator->eraseFromParent();
1483 builder.restoreIP(*afterIP);
1484
1485 // after the construct, deallocate private reduction variables
1486 SmallVector<Region *> reductionRegions;
1487 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1488 [](omp::DeclareReductionOp reductionDecl) {
1489 return &reductionDecl.getCleanupRegion();
1490 });
1491 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1492 moduleTranslation, builder,
1493 "omp.reduction.cleanup");
1494 return success();
1495}
1496
1497static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1498 if (!attr)
1499 return {};
1500 return *attr;
1501}
1502
1503// TODO: not used by omp.parallel
1504template <typename OP>
1506 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1507 LLVM::ModuleTranslation &moduleTranslation,
1508 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1510 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1511 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1512 llvm::ArrayRef<bool> isByRef) {
1513 if (op.getNumReductionVars() == 0)
1514 return success();
1515
1516 SmallVector<DeferredStore> deferredStores;
1517
1518 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1519 allocaIP, reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 deferredStores, isByRef)))
1522 return failure();
1523
1524 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1525 allocaIP.getBlock(), reductionDecls,
1526 privateReductionVariables, reductionVariableMap,
1527 isByRef, deferredStores);
1528}
1529
1530/// Return the llvm::Value * corresponding to the `privateVar` that
1531/// is being privatized. It isn't always as simple as looking up
1532/// moduleTranslation with privateVar. For instance, in case of
1533/// an allocatable, the descriptor for the allocatable is privatized.
1534/// This descriptor is mapped using an MapInfoOp. So, this function
1535/// will return a pointer to the llvm::Value corresponding to the
1536/// block argument for the mapped descriptor.
1537static llvm::Value *
1538findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1539 LLVM::ModuleTranslation &moduleTranslation,
1540 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1541 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1542 return moduleTranslation.lookupValue(privateVar);
1543
1544 Value blockArg = (*mappedPrivateVars)[privateVar];
1545 Type privVarType = privateVar.getType();
1546 Type blockArgType = blockArg.getType();
1547 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1548 "A block argument corresponding to a mapped var should have "
1549 "!llvm.ptr type");
1550
1551 if (privVarType == blockArgType)
1552 return moduleTranslation.lookupValue(blockArg);
1553
1554 // This typically happens when the privatized type is lowered from
1555 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1556 // struct/pair is passed by value. But, mapped values are passed only as
1557 // pointers, so before we privatize, we must load the pointer.
1558 if (!isa<LLVM::LLVMPointerType>(privVarType))
1559 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1560 moduleTranslation.lookupValue(blockArg));
1561
1562 return moduleTranslation.lookupValue(privateVar);
1563}
1564
1565/// Initialize a single (first)private variable. You probably want to use
1566/// allocateAndInitPrivateVars instead of this.
1567/// This returns the private variable which has been initialized. This
1568/// variable should be mapped before constructing the body of the Op.
1570initPrivateVar(llvm::IRBuilderBase &builder,
1571 LLVM::ModuleTranslation &moduleTranslation,
1572 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1573 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1574 llvm::BasicBlock *privInitBlock,
1575 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1576 Region &initRegion = privDecl.getInitRegion();
1577 if (initRegion.empty())
1578 return llvmPrivateVar;
1579
1580 assert(nonPrivateVar);
1581 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1582 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1583
1584 // in-place convert the private initialization region
1586 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1587 moduleTranslation, &phis)))
1588 return llvm::createStringError(
1589 "failed to inline `init` region of `omp.private`");
1590
1591 assert(phis.size() == 1 && "expected one allocation to be yielded");
1592
1593 // clear init region block argument mapping in case it needs to be
1594 // re-created with a different source for another use of the same
1595 // reduction decl
1596 moduleTranslation.forgetMapping(initRegion);
1597
1598 // Prefer the value yielded from the init region to the allocated private
1599 // variable in case the region is operating on arguments by-value (e.g.
1600 // Fortran character boxes).
1601 return phis[0];
1602}
1603
1604/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
1606 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1607 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1608 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1609 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1610 return initPrivateVar(
1611 builder, moduleTranslation, privDecl,
1612 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1613 mappedPrivateVars),
1614 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1615}
1616
/// Run the `init` region of every privatizer in \p privateVarsInfo inside a
/// dedicated "omp.private.init" block, mapping each privatized entry block
/// argument to the resulting LLVM value. Returns success immediately when
/// there is nothing to privatize.
static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation,
                PrivateVarsInfo &privateVarsInfo,
                llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  if (privateVarsInfo.blockArgs.empty())
    return llvm::Error::success();

  // All initialization logic is gathered into a single split-off block.
  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
  setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);

  for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
        llvmPrivateVar, privInitBlock, mappedPrivateVars);

    if (!privVarOrErr)
      return privVarOrErr.takeError();

    // Prefer the value yielded from the init region (it may differ from the
    // original allocation, e.g. for by-value arguments).
    llvmPrivateVar = privVarOrErr.get();
    moduleTranslation.mapValue(blockArg, llvmPrivateVar);

  }

  return llvm::Error::success();
}
1647
/// Allocate and initialize delayed private variables. Returns the basic block
/// which comes after all of these allocations. llvm::Value * for each of these
/// private variables are populated in llvmPrivateVars.
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    PrivateVarsInfo &privateVarsInfo,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  // Split the alloca block so that everything following the allocas moves
  // into a fresh "omp.region.after_alloca" block.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator since the alloca block was split above.
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  // The new terminator is an unconditional branch created by the splitBB above.
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS = moduleTranslation.getLLVMModule()
                               ->getDataLayout()
                               .getProgramAddressSpace();

  // One alloca per privatizer; results are appended to
  // privateVarsInfo.llvmVars in privatizer order.
  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs)) {
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
    // On targets where the alloca address space differs from the program
    // address space, cast the pointer into the program address space.
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  }

  return afterAllocas;
}
1698
/// This can't always be determined statically, but when we can, it is good to
/// avoid generating compiler-added barriers which will deadlock the program.
/// Walks up the parent chain of the operation: true when the closest
/// single-threaded ancestor (omp.single / omp.critical) is reached before any
/// omp.parallel ancestor; false otherwise (conservative default).
  for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
      return true;

    // An intervening parallel region re-introduces multiple threads, e.g.
    // omp.single {
    //   omp.parallel {
    //     op
    //   }
    // }
    if (mlir::isa<omp::ParallelOp>(parent))
      return false;
  }
  return false;
}
1718
/// Inline the `copy` region of every firstprivate privatizer, copying from the
/// original ("mold") variable into the corresponding private copy. When
/// \p insertBarrier is set, a barrier is emitted afterwards unless
/// opIsInSingleThread proves the construct runs single-threaded.
static LogicalResult copyFirstPrivateVars(
    mlir::Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    ArrayRef<llvm::Value *> llvmPrivateVars,
    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  // All copy logic is gathered into a single split-off block.
  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);

  for (auto [decl, moldVar, llvmVar] :
       llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg (the mold/original variable)
    moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");


    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  if (insertBarrier && !opIsInSingleThread(op)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(res, *op)))
      return failure();
  }

  return success();
}
1778
1779static LogicalResult copyFirstPrivateVars(
1780 mlir::Operation *op, llvm::IRBuilderBase &builder,
1781 LLVM::ModuleTranslation &moduleTranslation,
1782 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1783 ArrayRef<llvm::Value *> llvmPrivateVars,
1784 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1785 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1786 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1787 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1788 // map copyRegion rhs arg
1789 llvm::Value *moldVar = findAssociatedValue(
1790 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1791 assert(moldVar);
1792 return moldVar;
1793 });
1794 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1795 llvmPrivateVars, privateDecls, insertBarrier,
1796 mappedPrivateVars);
1797}
1798
/// Inline the `dealloc` region of each privatizer over the corresponding LLVM
/// private variable, emitting whatever cleanup logic the privatizer defines.
static LogicalResult
cleanupPrivateVars(llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation, Location loc,
                   SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
  // private variable deallocation
  SmallVector<Region *> privateCleanupRegions;
  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();
                  });

  if (failed(inlineOmpRegionCleanup(
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");

  return success();
}
1819
/// Returns true if the construct contains omp.cancel or omp.cancellation_point
  // omp.cancel and omp.cancellation_point must be "closely nested" so they will
  // be visible and not inside of function calls. This is enforced by the
  // verifier.
  return op
      ->walk([](Operation *child) {
        // Stop the walk as soon as a cancellation-related op is found.
        if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}
1833
/// Converts an OpenMP sections construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  // Bail out on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

      sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
      reductionDecls, privateReductionVariables, reductionVariableMap,
      isByRef)))
    return failure();


  // Build one body-generation callback per omp.section nested in the region.
  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  bool isCancellable = constructIsCancellable(sectionsOp);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
}
1939
/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  if (failed(checkImplementationStatus(*singleOp)))
    return failure();

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  // Nothing to finalize for `single`.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  // Translate each copyprivate variable together with its copy-helper symbol.
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
          llvmCPFuncs);

  if (failed(handleError(afterIP, *singleOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
1982
/// Returns true when every (non-debug) use of the teams op's reduction block
/// arguments lives inside one single omp.distribute op, in which case the
/// reduction is handled by the distribute construct rather than by teams.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
  auto iface =
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  // Check that all uses of the reduction block arg has the same distribute op
  // parent.
  Operation *distOp = nullptr;
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();
      // Ignore debug uses.
      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);
        continue;
      }

      auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
      // Use is not inside a distribute op - return false
      if (!currentDistOp)
        return false;
      // Multiple distribute operations - return false
      Operation *currentOp = currentDistOp.getOperation();
      if (distOp && (distOp != currentOp))
        return false;

      distOp = currentOp;
    }

  // If we are going to use distribute reduction then remove any debug uses of
  // the reduction parameters in teamsOp. Otherwise they will be left without
  // any mapped value in moduleTranslation and will eventually error out.
  for (auto *use : debugUses)
    use->erase();
  return true;
}
2018
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Bail out on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(*op)))
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = op.getNumReductionVars();
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Only do teams reduction if there is no distribute op that captures the
  // reduction instead.
  bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

    collectReductionDecls(op, reductionDecls);

        op, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the optional clause operands; each stays null when absent.
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (!op.getNumTeamsUpperVars().empty())
    numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));

  llvm::Value *threadLimit = nullptr;
  if (!op.getThreadLimitVars().empty())
    threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
    // Process the reductions if required.
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
2098
/// Translate the operands/kinds of a `depend` clause into the DependData list
/// consumed by OpenMPIRBuilder. No-op when there are no depend variables.
static void
buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
                LLVM::ModuleTranslation &moduleTranslation,
  if (dependVars.empty())
    return;
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
    switch (
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
      break;
    // The OpenMP runtime requires that the codegen for 'depend' clause for
    // 'out' dependency kind must be the same as codegen for 'depend' clause
    // with 'inout' dependency.
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
      break;
    };
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
  }
}
2131
/// Shared implementation of a callback which adds a terminator for the new
/// block created for the branch taken when an openmp construct is cancelled.
/// The terminator is saved in \p cancelTerminators. This callback is invoked
/// only if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct.
static void
                         llvm::IRBuilderBase &llvmBuilder,
                         llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
                         llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    // Restore the caller's insertion point when this lambda returns.
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    // ip is currently in the block branched to if cancellation occurred.
    // We need to create a branch to terminate that block.
    llvmBuilder.restoreIP(ip);

    // We must still clean up the construct after cancelling it, so we need to
    // branch to the block that finalizes the taskgroup.
    // That block has not been created yet so use this block as a dummy for now
    // and fix this after creating the operation.
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  };
  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
  // created in case the body contains omp.cancel (which will then expect to be
  // able to find this cleanup callback).
  ompBuilder.pushFinalizationCB(
      {finiCB, cancelDirective, constructIsCancellable(op)});
}
2163
/// If we cancelled the construct, we should branch to the finalization block of
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
/// is immediately before the continuation block. Now this finalization has
/// been created we can fix the branch.
static void
                        llvm::OpenMPIRBuilder &ompBuilder,
                        const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  // The cleanup block is the single predecessor of the continuation block.
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    // Retarget the placeholder branch created by pushCancelFinalizationCB.
    cancelBranch->setSuccessor(0, constructFini);
  }
}
2180
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task. Allocated
  /// with malloc by generateTaskContextStruct(); released via freeStructPtr().
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2238
2239void TaskContextStructManager::generateTaskContextStruct() {
2240 if (privateDecls.empty())
2241 return;
2242 privateVarTypes.reserve(privateDecls.size());
2243
2244 for (omp::PrivateClauseOp &privOp : privateDecls) {
2245 // Skip private variables which can safely be allocated and initialised
2246 // inside of the task
2247 if (!privOp.readsFromMold())
2248 continue;
2249 Type mlirType = privOp.getType();
2250 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2251 }
2252
2253 if (privateVarTypes.empty())
2254 return;
2255
2256 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2257 privateVarTypes);
2258
2259 llvm::DataLayout dataLayout =
2260 builder.GetInsertBlock()->getModule()->getDataLayout();
2261 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2262 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2263
2264 // Heap allocate the structure
2265 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2266 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2267 "omp.task.context_ptr");
2268}
2269
2270SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2271 llvm::Value *altStructPtr) const {
2272 SmallVector<llvm::Value *> ret;
2273
2274 // Create GEPs for each struct member
2275 ret.reserve(privateDecls.size());
2276 llvm::Value *zero = builder.getInt32(0);
2277 unsigned i = 0;
2278 for (auto privDecl : privateDecls) {
2279 if (!privDecl.readsFromMold()) {
2280 // Handle this inside of the task so we don't pass unnessecary vars in
2281 ret.push_back(nullptr);
2282 continue;
2283 }
2284 llvm::Value *iVal = builder.getInt32(i);
2285 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2286 ret.push_back(gep);
2287 i += 1;
2288 }
2289 return ret;
2290}
2291
2292void TaskContextStructManager::createGEPsToPrivateVars() {
2293 if (!structPtr)
2294 assert(privateVarTypes.empty());
2295 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2296 // with null values for skipped private decls
2297
2298 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2299}
2300
2301void TaskContextStructManager::freeStructPtr() {
2302 if (!structPtr)
2303 return;
2304
2305 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2306 // Ensure we don't put the call to free() after the terminator
2307 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2308 builder.CreateFree(structPtr);
2309}
2310
2311/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2312static LogicalResult
2313convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2314 LLVM::ModuleTranslation &moduleTranslation) {
2315 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2316 if (failed(checkImplementationStatus(*taskOp)))
2317 return failure();
2318
2319 PrivateVarsInfo privateVarsInfo(taskOp);
2320 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2321 privateVarsInfo.privatizers};
2322
2323 // Allocate and copy private variables before creating the task. This avoids
2324 // accessing invalid memory if (after this scope ends) the private variables
2325 // are initialized from host variables or if the variables are copied into
2326 // from host variables (firstprivate). The insertion point is just before
2327 // where the code for creating and scheduling the task will go. That puts this
2328 // code outside of the outlined task region, which is what we want because
2329 // this way the initialization and copy regions are executed immediately while
2330 // the host variable data are still live.
2331
2332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2333 findAllocaInsertPoint(builder, moduleTranslation);
2334
2335 // Not using splitBB() because that requires the current block to have a
2336 // terminator.
2337 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2338 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2339 builder.getContext(), "omp.task.start",
2340 /*Parent=*/builder.GetInsertBlock()->getParent());
2341 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2342 builder.SetInsertPoint(branchToTaskStartBlock);
2343
2344 // Now do this again to make the initialization and copy blocks
2345 llvm::BasicBlock *copyBlock =
2346 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2347 llvm::BasicBlock *initBlock =
2348 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2349
2350 // Now the control flow graph should look like
2351 // starter_block:
2352 // <---- where we started when convertOmpTaskOp was called
2353 // br %omp.private.init
2354 // omp.private.init:
2355 // br %omp.private.copy
2356 // omp.private.copy:
2357 // br %omp.task.start
2358 // omp.task.start:
2359 // <---- where we want the insertion point to be when we call createTask()
2360
2361 // Save the alloca insertion point on ModuleTranslation stack for use in
2362 // nested regions.
2364 moduleTranslation, allocaIP);
2365
2366 // Allocate and initialize private variables
2367 builder.SetInsertPoint(initBlock->getTerminator());
2368
2369 // Create task variable structure
2370 taskStructMgr.generateTaskContextStruct();
2371 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2372 // of the body otherwise it will be the GEP not the struct which is fowarded
2373 // to the outlined function. GEPs forwarded in this way are passed in a
2374 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2375 // which may not be executed until after the current stack frame goes out of
2376 // scope.
2377 taskStructMgr.createGEPsToPrivateVars();
2378
2379 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2380 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2381 privateVarsInfo.blockArgs,
2382 taskStructMgr.getLLVMPrivateVarGEPs())) {
2383 // To be handled inside the task.
2384 if (!privDecl.readsFromMold())
2385 continue;
2386 assert(llvmPrivateVarAlloc &&
2387 "reads from mold so shouldn't have been skipped");
2388
2389 llvm::Expected<llvm::Value *> privateVarOrErr =
2390 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2391 blockArg, llvmPrivateVarAlloc, initBlock);
2392 if (!privateVarOrErr)
2393 return handleError(privateVarOrErr, *taskOp.getOperation());
2394
2396
2397 // TODO: this is a bit of a hack for Fortran character boxes.
2398 // Character boxes are passed by value into the init region and then the
2399 // initialized character box is yielded by value. Here we need to store the
2400 // yielded value into the private allocation, and load the private
2401 // allocation to match the type expected by region block arguments.
2402 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2403 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2404 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2405 // Load it so we have the value pointed to by the GEP
2406 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2407 llvmPrivateVarAlloc);
2408 }
2409 assert(llvmPrivateVarAlloc->getType() ==
2410 moduleTranslation.convertType(blockArg.getType()));
2411
2412 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2413 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2414 // stack allocated structure.
2415 }
2416
2417 // firstprivate copy region
2418 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2419 if (failed(copyFirstPrivateVars(
2420 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2421 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2422 taskOp.getPrivateNeedsBarrier())))
2423 return llvm::failure();
2424
2425 // Set up for call to createTask()
2426 builder.SetInsertPoint(taskStartBlock);
2427
2428 auto bodyCB = [&](InsertPointTy allocaIP,
2429 InsertPointTy codegenIP) -> llvm::Error {
2430 // Save the alloca insertion point on ModuleTranslation stack for use in
2431 // nested regions.
2433 moduleTranslation, allocaIP);
2434
2435 // translate the body of the task:
2436 builder.restoreIP(codegenIP);
2437
2438 llvm::BasicBlock *privInitBlock = nullptr;
2439 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2440 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2441 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2442 privateVarsInfo.mlirVars))) {
2443 auto [blockArg, privDecl, mlirPrivVar] = zip;
2444 // This is handled before the task executes
2445 if (privDecl.readsFromMold())
2446 continue;
2447
2448 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2449 llvm::Type *llvmAllocType =
2450 moduleTranslation.convertType(privDecl.getType());
2451 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2452 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2453 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2454
2455 llvm::Expected<llvm::Value *> privateVarOrError =
2456 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2457 blockArg, llvmPrivateVar, privInitBlock);
2458 if (!privateVarOrError)
2459 return privateVarOrError.takeError();
2460 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2461 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2462 }
2463
2464 taskStructMgr.createGEPsToPrivateVars();
2465 for (auto [i, llvmPrivVar] :
2466 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2467 if (!llvmPrivVar) {
2468 assert(privateVarsInfo.llvmVars[i] &&
2469 "This is added in the loop above");
2470 continue;
2471 }
2472 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2473 }
2474
2475 // Find and map the addresses of each variable within the task context
2476 // structure
2477 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2478 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2479 privateVarsInfo.privatizers)) {
2480 // This was handled above.
2481 if (!privateDecl.readsFromMold())
2482 continue;
2483 // Fix broken pass-by-value case for Fortran character boxes
2484 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2485 llvmPrivateVar = builder.CreateLoad(
2486 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2487 }
2488 assert(llvmPrivateVar->getType() ==
2489 moduleTranslation.convertType(blockArg.getType()));
2490 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2491 }
2492
2493 auto continuationBlockOrError = convertOmpOpRegions(
2494 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2495 if (failed(handleError(continuationBlockOrError, *taskOp)))
2496 return llvm::make_error<PreviouslyReportedError>();
2497
2498 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2499
2500 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2501 privateVarsInfo.llvmVars,
2502 privateVarsInfo.privatizers)))
2503 return llvm::make_error<PreviouslyReportedError>();
2504
2505 // Free heap allocated task context structure at the end of the task.
2506 taskStructMgr.freeStructPtr();
2507
2508 return llvm::Error::success();
2509 };
2510
2511 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2512 SmallVector<llvm::BranchInst *> cancelTerminators;
2513 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2514 // which is canceled. This is handled here because it is the task's cleanup
2515 // block which should be branched to.
2516 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2517 llvm::omp::Directive::OMPD_taskgroup);
2518
2520 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2521 moduleTranslation, dds);
2522
2523 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2524 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2525 moduleTranslation.getOpenMPBuilder()->createTask(
2526 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2527 moduleTranslation.lookupValue(taskOp.getFinal()),
2528 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2529 taskOp.getMergeable(),
2530 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2531 moduleTranslation.lookupValue(taskOp.getPriority()));
2532
2533 if (failed(handleError(afterIP, *taskOp)))
2534 return failure();
2535
2536 // Set the correct branch target for task cancellation
2537 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2538
2539 builder.restoreIP(*afterIP);
2540 return success();
2541}
2542
/// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
///
/// Privatization is split in two: clauses that read from their mold
/// (firstprivate-like) are initialized up front and kept in a task context
/// structure managed by `taskStructMgr` so every task duplicate gets its own
/// copy; purely task-local privates are allocated/initialized inside the task
/// body callback. Collapsed loop nests have their bounds folded into a single
/// trip count before calling `createTaskloop`.
static LogicalResult
convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto taskloopOp = cast<omp::TaskloopOp>(opInst);
  // Fail early if the op uses clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // It stores the pointer of allocated firstprivate copies,
  // which can be used later for freeing the allocated space.
  SmallVector<llvm::Value *> llvmFirstPrivateVars;
  PrivateVarsInfo privateVarsInfo(taskloopOp);
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
                                         privateVarsInfo.privatizers};

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Split the current block so private init/copy code is emitted in dedicated
  // blocks that run before the taskloop start block.
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.taskloop.start",
      /*Parent=*/builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskloopStartBlock =
      builder.CreateBr(taskloopStartBlock);
  builder.SetInsertPoint(branchToTaskloopStartBlock);

  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.init");

  // NOTE(review): this listing appears to be missing the declaration line
  // preceding this continuation (presumably a SaveStack-style frame guard
  // saving the alloca insertion point) — confirm against the full source.
      moduleTranslation, allocaIP);

  // Allocate and initialize private variables
  builder.SetInsertPoint(initBlock->getTerminator());

  // TODO: don't allocate if the loop has zero iterations.
  taskStructMgr.generateTaskContextStruct();
  taskStructMgr.createGEPsToPrivateVars();

  llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());

  // Initialize mold-reading (firstprivate-like) variables into the task
  // context structure before the task is created.
  for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
    // To be handled inside the taskloop.
    if (!privDecl.readsFromMold())
      continue;
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

    llvm::Expected<llvm::Value *> privateVarOrErr =
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskloopOp.getOperation());

    llvmFirstPrivateVars[i] = privateVarOrErr.get();

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

    // Hack for Fortran character boxes passed/yielded by value: store the
    // yielded value into the private allocation, then load it back so the
    // type matches what region block arguments expect.
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
      // Load it so we have the value pointed to by the GEP
      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);
    }
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));
  }

  // firstprivate copy region
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
  if (failed(copyFirstPrivateVars(
          taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
          taskloopOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  // Set up insertion point for call to createTaskloop()
  builder.SetInsertPoint(taskloopStartBlock);

  // Body callback: emits the taskloop region. Runs inside the outlined task.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    // NOTE(review): the declaration line for this continuation appears to be
    // missing from this listing — confirm against the full source.
        moduleTranslation, allocaIP);

    // translate the body of the taskloop:
    builder.restoreIP(codegenIP);

    // Allocate and initialize privates that do NOT read from their mold;
    // these live on the task's own stack rather than the context struct.
    llvm::BasicBlock *privInitBlock = nullptr;
    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
             privateVarsInfo.mlirVars))) {
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      // This is handled before the task executes
      if (privDecl.readsFromMold())
        continue;

      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");

      llvm::Expected<llvm::Value *> privateVarOrError =
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    }

    // Overwrite llvmVars entries with GEPs into the (per-task) context
    // structure wherever the variable is stored there; null GEPs correspond
    // to task-local privates handled in the loop above.
    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
      if (!llvmPrivVar) {
        assert(privateVarsInfo.llvmVars[i] &&
               "This is added in the loop above");
        continue;
      }
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    }

    // Find and map the addresses of each variable within the taskloop context
    // structure
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
                         privateVarsInfo.privatizers)) {
      // This was handled above.
      if (!privateDecl.readsFromMold())
        continue;
      // Fix broken pass-by-value case for Fortran character boxes
      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
      }
      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
    }

    auto continuationBlockOrError =
        convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
                            builder, moduleTranslation);

    if (failed(handleError(continuationBlockOrError, opInst)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    // This is freeing the private variables as mapped inside of the task: these
    // will be per-task private copies possibly after task duplication. This is
    // handled transparently by how these are passed to the structure passed
    // into the outlined function. When the task is duplicated, that structure
    // is duplicated too.
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  taskloopOp.getLoc(), privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();
    // Similarly, the task context structure freed inside the task is the
    // per-task copy after task duplication.
    taskStructMgr.freeStructPtr();

    return llvm::Error::success();
  };

  // Taskloop divides into an appropriate number of tasks by repeatedly
  // duplicating the original task. Each time this is done, the task context
  // structure must be duplicated too.
  // NOTE(review): the trailing-return-type / opening-brace line of this
  // lambda appears to be missing from this listing — confirm against the
  // full source.
  auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                       llvm::Value *destPtr, llvm::Value *srcPtr)
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.restoreIP(codegenIP);

    llvm::Type *ptrTy =
        builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
    llvm::Value *src =
        builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");

    // Allocate a fresh context structure for the duplicated task and publish
    // its address through destPtr.
    TaskContextStructManager &srcStructMgr = taskStructMgr;
    TaskContextStructManager destStructMgr(builder, moduleTranslation,
                                           privateVarsInfo.privatizers);
    destStructMgr.generateTaskContextStruct();
    llvm::Value *dest = destStructMgr.getStructPtr();
    dest->setName("omp.taskloop.context.dest");
    builder.CreateStore(dest, destPtr);

    // NOTE(review): the declarations of `srcGEPs`/`destGEPs` used below
    // appear to be missing from this listing — confirm against the full
    // source.
    srcStructMgr.createGEPsToPrivateVars(src);
    destStructMgr.createGEPsToPrivateVars(dest);

    // Inline init regions.
    for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
         llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
                         privateVarsInfo.blockArgs, destGEPs)) {
      // To be handled inside task body.
      if (!privDecl.readsFromMold())
        continue;
      assert(llvmPrivateVarAlloc &&
             "reads from mold so shouldn't have been skipped");

      llvm::Expected<llvm::Value *> privateVarOrErr =
          initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
                         llvmPrivateVarAlloc, builder.GetInsertBlock());
      if (!privateVarOrErr)
        return privateVarOrErr.takeError();

      // TODO: this is a bit of a hack for Fortran character boxes.
      // Character boxes are passed by value into the init region and then the
      // initialized character box is yielded by value. Here we need to store
      // the yielded value into the private allocation, and load the private
      // allocation to match the type expected by region block arguments.
      if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
          !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
        // Load it so we have the value pointed to by the GEP
        llvmPrivateVarAlloc = builder.CreateLoad(
            privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
      }
      assert(llvmPrivateVarAlloc->getType() ==
             moduleTranslation.convertType(blockArg.getType()));

      // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
      // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
      // through a stack allocated structure.
    }

    if (failed(copyFirstPrivateVars(
            &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
            privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    return builder.saveIP();
  };

  auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());

  // Deferred lookup of the canonical loop info built for the wrapped loop.
  auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
    llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
    return loopInfo;
  };

  Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
  Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
  Operation::operand_range steps = loopOp.getLoopSteps();
  llvm::Type *boundType =
      moduleTranslation.lookupValue(lowerBounds[0])->getType();
  llvm::Value *lbVal = nullptr;
  llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  llvm::Value *stepVal = nullptr;
  if (loopOp.getCollapseNumLoops() > 1) {
    // In cases where Collapse is used with Taskloop, the upper bound of the
    // iteration space needs to be recalculated to cater for the collapsed loop.
    // The Collapsed Loop UpperBound is the product of all collapsed
    // loop's tripcount.
    // The LowerBound for collapsed loops is always 1. When the loops are
    // collapsed, it will reset the bounds and introduce processing to ensure
    // the index's are presented as expected. As this happens after creating
    // Taskloop, these bounds need predicting. Example:
    // !$omp taskloop collapse(2)
    // do i = 1, 10
    //   do j = 1, 5
    //     ..
    //   end do
    // end do
    // This loop above has a total of 50 iterations, so the lb will be 1, and
    // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
    // that i and j are properly presented when used in the loop.
    for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
      llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
      llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
      llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);
      // In some cases, such as where the ub is less than the lb so the loop
      // steps down, the calculation for the loopTripCount is swapped. To ensure
      // the correct value is found, calculate both UB - LB and LB - UB then
      // select which value to use depending on how the loop has been
      // configured.
      llvm::Value *loopLbMinusOne = builder.CreateSub(
          loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *loopUbMinusOne = builder.CreateSub(
          loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
      llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
      llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
      llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
      llvm::Value *loopTripCount =
          builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
      loopTripCount = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
      // For loops that have a step value not equal to 1, we need to adjust the
      // trip count to ensure the correct number of iterations for the loop is
      // captured.
      llvm::Value *loopTripCountDivStep =
          builder.CreateSDiv(loopTripCount, loopStep);
      loopTripCountDivStep = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
      llvm::Value *loopTripCountRem =
          builder.CreateSRem(loopTripCount, loopStep);
      loopTripCountRem = builder.CreateBinaryIntrinsic(
          llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
      llvm::Value *needsRoundUp = builder.CreateICmpNE(
          loopTripCountRem,
          builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
                          0));
      loopTripCount =
          builder.CreateAdd(loopTripCountDivStep,
                            builder.CreateZExtOrTrunc(
                                needsRoundUp, loopTripCountDivStep->getType()));
      ubVal = builder.CreateMul(ubVal, loopTripCount);
    }
    lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
    stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
  } else {
    lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
    ubVal = moduleTranslation.lookupValue(upperBounds[0]);
    stepVal = moduleTranslation.lookupValue(steps[0]);
  }
  assert(lbVal != nullptr && "Expected value for lbVal");
  assert(ubVal != nullptr && "Expected value for ubVal");
  assert(stepVal != nullptr && "Expected value for stepVal");

  // Translate the if/grainsize/num_tasks clauses. sched encodes which of
  // grainsize (1) or num_tasks (2) was present; 0 means neither.
  llvm::Value *ifCond = nullptr;
  llvm::Value *grainsize = nullptr;
  int sched = 0; // default
  mlir::Value grainsizeVal = taskloopOp.getGrainsize();
  mlir::Value numTasksVal = taskloopOp.getNumTasks();
  if (Value ifVar = taskloopOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  if (grainsizeVal) {
    grainsize = moduleTranslation.lookupValue(grainsizeVal);
    sched = 1; // grainsize
  } else if (numTasksVal) {
    grainsize = moduleTranslation.lookupValue(numTasksVal);
    sched = 2; // num_tasks
  }

  // Only register the duplication callback when a context struct exists;
  // otherwise there is nothing to duplicate per task.
  llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
  if (taskStructMgr.getStructPtr())
    taskDupOrNull = taskDupCB;

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  SmallVector<llvm::BranchInst *> cancelTerminators;
  // The directive to match here is OMPD_taskgroup because it is the
  // taskgroup which is canceled. This is handled here because it is the
  // task's cleanup block which should be branched to. It doesn't depend upon
  // nogroup because even in that case the taskloop might still be inside an
  // explicit taskgroup.
  pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
                           llvm::omp::Directive::OMPD_taskgroup);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTaskloop(
          ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
          taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
          sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
          taskloopOp.getMergeable(),
          moduleTranslation.lookupValue(taskloopOp.getPriority()),
          loopOp.getCollapseNumLoops(), taskDupOrNull,
          taskStructMgr.getStructPtr());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Set the correct branch target for task cancellation.
  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());

  builder.restoreIP(*afterIP);
  return success();
}
2926
2927/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2928static LogicalResult
2929convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2930 LLVM::ModuleTranslation &moduleTranslation) {
2931 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2932 if (failed(checkImplementationStatus(*tgOp)))
2933 return failure();
2934
2935 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2936 builder.restoreIP(codegenIP);
2937 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2938 builder, moduleTranslation)
2939 .takeError();
2940 };
2941
2942 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2943 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2944 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2945 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2946 bodyCB);
2947
2948 if (failed(handleError(afterIP, *tgOp)))
2949 return failure();
2950
2951 builder.restoreIP(*afterIP);
2952 return success();
2953}
2954
2955static LogicalResult
2956convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2957 LLVM::ModuleTranslation &moduleTranslation) {
2958 if (failed(checkImplementationStatus(*twOp)))
2959 return failure();
2960
2961 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2962 return success();
2963}
2964
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
///
/// Handles privatization (alloc/init/firstprivate-copy), reduction variable
/// setup and finalization, linear-clause processing, cancellation
/// finalization, the `distribute parallel do` composite case, and no-loop
/// SPMD kernel detection before applying the worksharing-loop runtime calls.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  // Fail early if the op uses clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    // Chunk size is coerced to the induction variable's integer type.
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // If this wsloop is directly nested in an omp.distribute, pick up the
  // dist_schedule clause information from the parent.
  omp::DistributeOp distributeOp = nullptr;
  llvm::Value *distScheduleChunk = nullptr;
  bool hasDistSchedule = false;
  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
    hasDistSchedule = distributeOp.getDistScheduleStatic();
    if (distributeOp.getDistScheduleChunkSize()) {
      llvm::Value *chunkVar = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
    }
  }

  PrivateVarsInfo privateVarsInfo(wsloopOp);

  // NOTE(review): the declaration of `reductionDecls` appears to be missing
  // from this listing — confirm against the full source.
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // NOTE(review): the declaration line for `afterAllocas` (presumably a call
  // allocating the private variables) appears to be missing from this
  // listing — confirm against the full source.
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  if (failed(copyFirstPrivateVars(
          wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
          wsloopOp.getPrivateNeedsBarrier())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  // The only legal way for the direct parent to be omp.distribute is that this
  // represents 'distribute parallel do'. Otherwise, this is a regular
  // worksharing loop.
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

  SmallVector<llvm::BranchInst *> cancelTerminators;
  pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
                           llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!wsloopOp.getLinearVars().empty()) {
    auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);

    for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(
          builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
          idx);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  // NOTE(review): the declaration line for `regionBlock` (presumably a call
  // to convertOmpOpRegions) appears to be missing from this listing —
  // confirm against the full source.
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);

  // Emit Initialization and Update IR for linear variables
  if (!wsloopOp.getLinearVars().empty()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    builder.restoreIP(*afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
    linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Check if we can generate no-loop kernel
  bool noLoopMode = false;
  omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
  if (targetOp) {
    Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
    // We need this check because, without it, noLoopMode would be set to true
    // for every omp.wsloop nested inside a no-loop SPMD target region, even if
    // that loop is not the top-level SPMD one.
    if (loopOp == targetCapturedOp) {
      omp::TargetRegionFlags kernelFlags =
          targetOp.getKernelExecFlags(targetCapturedOp);
      if (omp::bitEnumContainsAll(kernelFlags,
                                  omp::TargetRegionFlags::spmd |
                                      omp::TargetRegionFlags::no_loop) &&
          !omp::bitEnumContainsAny(kernelFlags,
                                   omp::TargetRegionFlags::generic))
        noLoopMode = true;
    }
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Emit finalization and in-place rewrites for linear vars.
  if (!wsloopOp.getLinearVars().empty()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                loopInfo->getLastIter());
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                           index);
    builder.restoreIP(oldIP);
  }

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(
          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
          privateReductionVariables, isByRef, wsloopOp.getNowait(),
          /*isTeamsReduction=*/false)))
    return failure();

  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3174
/// Converts the OpenMP parallel operation to LLVM IR.
///
/// Privatization and reductions are handled inside the body-generation
/// callback handed to OpenMPIRBuilder::createParallel; the builder's own
/// privatization callback (`privCB`) is therefore a no-op. Reduction cleanup
/// and the cancellation barrier are emitted by the finalization callback
/// (`finiCB`).
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Per-reduction flags: true when the reduction variable is passed by
  // reference.
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isCancellable = constructIsCancellable(opInst);

  // Bail out early on clauses this translation does not yet support.
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(opInst);

  // Collect reduction declarations
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Generates the body of the outlined parallel region: allocates and
  // initializes private and reduction variables, inlines the MLIR region,
  // then emits the reduction-combining code.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
        builder, moduleTranslation, privateVarsInfo, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point just before the terminator so
    // subsequent allocas stay grouped in the entry block.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    if (handleError(
            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            *opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    // firstprivate copies may need a barrier between init and use; the op
    // carries that decision.
    if (failed(copyFirstPrivateVars(
            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
          owningReductionGenRefDataPtrGens;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           owningReductionGenRefDataPtrGens,
                           privateReductionVariables, reductionInfos, isByRef);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info. A temporary `unreachable` anchors the
      // insertion point while createReductions restructures the block; it is
      // erased once the continuation point is known.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              builder.saveIP(), allocaIP, reductionInfos, isByRef,
              /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }

    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    // If we could be performing cancellation, add the cancellation barrier on
    // the way out of the outlined region.
    if (isCancellable) {
      auto IPOrErr = ompBuilder->createBarrier(
          llvm::OpenMPIRBuilder::LocationDescription(builder),
          llvm::omp::Directive::OMPD_unknown,
          /* ForceSimpleCall */ false,
          /* CheckCancelFlag */ false);
      if (!IPOrErr)
        return IPOrErr.takeError();
    }

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate the `if`, `num_threads` and `proc_bind` clauses, when present.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (!opInst.getNumThreadsVars().empty())
    numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  // Continue emitting code after the parallel construct.
  builder.restoreIP(*afterIP);
  return success();
}
3362
3363/// Convert Order attribute to llvm::omp::OrderKind.
3364static llvm::omp::OrderKind
3365convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3366 if (!o)
3367 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3368 switch (*o) {
3369 case omp::ClauseOrderKind::Concurrent:
3370 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3371 }
3372 llvm_unreachable("Unknown ClauseOrderKind kind");
3373}
3374
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
///
/// Handles private, linear, aligned, and reduction clauses. Because a simd
/// construct executes within a single thread, reductions are combined inline
/// after the loop rather than through the OpenMP runtime.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto simdOp = cast<omp::SimdOp>(opInst);

  // Bail out early on clauses this translation does not yet support.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(simdOp);

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
  DenseMap<Value, llvm::Value *> reductionVariableMap;
  SmallVector<llvm::Value *> privateReductionVariables(
      simdOp.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;
  collectReductionDecls(simdOp, reductionDecls);
  llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!simdOp.getLinearVars().empty()) {
    auto linearVarTypes = simdOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
      bool isImplicit = false;
      for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
               privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
        // If the linear variable is implicit, reuse the already
        // existing llvm::Value
        if (linearVar == mlirPrivVar) {
          isImplicit = true;
          linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                                llvmPrivateVar, idx);
          break;
        }
      }

      // Explicit linear variable: translate the MLIR value directly.
      if (!isImplicit)
        linearClauseProcessor.createLinearVar(
            builder, moduleTranslation,
            moduleTranslation.lookupValue(linearVar), idx);
    }
    for (mlir::Value linearStep : simdOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  if (failed(allocReductionVars(simdOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
  // SIMD.

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(simdOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // Translate simdlen/safelen clause values, if present.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());

  // Collect aligned pointers and their alignment constants for applySimd.
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");

    // Check if the alignment value is not a power of 2. If so, skip emitting
    // alignment.
    if (!intAttr.getValue().isPowerOf2())
      continue;

    // The load of the aligned pointer is emitted in the block preceding the
    // loop region, so it dominates all uses inside the loop.
    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
  // Emit Initialization for linear variables
  if (simdOp.getLinearVars().size()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());

    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
  }
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  // Store final linear values and replace in-region uses with them.
  linearClauseProcessor.emitStoresForLinearVar(builder);
  for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
    linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                         index);

  // We now need to reduce the per-simd-lane reduction variable into the
  // original variable. This works a bit differently to other reductions (e.g.
  // wsloop) because we don't need to call into the OpenMP runtime to handle
  // threads: everything happened in this one thread.
  for (auto [i, tuple] : llvm::enumerate(
           llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
                     privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

    // We have one less load for by-ref case because that load is now inside of
    // the reduction region.
    llvm::Value *redValue = originalVariable;
    if (!byRef)
      redValue =
          builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        reductionType, privateReductionVar, "red.private.value." + Twine(i));
    llvm::Value *reduced;

    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    if (failed(handleError(res, opInst)))
      return failure();
    builder.restoreIP(res.get());

    // For by-ref case, the store is inside of the reduction region.
    if (!byRef)
      builder.CreateStore(reduced, originalVariable);
  }

  // After the construct, deallocate private reduction variables.
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                    moduleTranslation, builder,
                                    "omp.reduction.cleanup")))
    return failure();

  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3570
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
///
/// Builds one canonical loop per rectangular loop in the nest, optionally
/// tiles them, and finally collapses the (possibly tiled) loops into a single
/// canonical loop that is published on the loop-info stack frame for the
/// enclosing worksharing construct to consume.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop of the nest carries the MLIR region body.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are built at the body insertion point of the loop just
      // outside them; trip counts are hoisted to the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    ompBuilder->createCanonicalLoop(
        loc, bodyGen, lowerBound, upperBound, step,
        /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Do tiling.
  if (const auto &tiles = loopOp.getTileSizes()) {
    llvm::Type *ivType = loopInfos.front()->getIndVarType();

    // Materialize the tile sizes as constants of the induction-variable type.
    for (auto tile : tiles.value()) {
      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
      tileSizes.push_back(tileVal);
    }

    std::vector<llvm::CanonicalLoopInfo *> newLoops =
        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);

    // Update afterIP to get the correct insertion point after
    // tiling.
    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
    afterIP = {afterAfterBB, afterAfterBB->begin()};

    // Update the loop infos.
    loopInfos.clear();
    for (const auto &newLoop : newLoops)
      loopInfos.push_back(newLoop);
  } // Tiling done.

  // Do collapse.
  const auto &numCollapse = loopOp.getCollapseNumLoops();
      loopInfos.begin(), loopInfos.begin() + (numCollapse));

  auto newTopLoopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});

  assert(newTopLoopInfo && "New top loop information is missing");
  // Publish the collapsed loop for the enclosing worksharing construct.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = newTopLoopInfo;
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);
  return success();
}
3696
/// Convert an omp.canonical_loop to LLVM-IR.
///
/// Emits an OpenMPIRBuilder canonical loop over the op's trip count, inlines
/// the MLIR region as the loop body, and (when the op produces a CLI result)
/// records the resulting CanonicalLoopInfo so later loop-transformation ops
/// (unroll/tile/fuse) can look it up.
static LogicalResult
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);

      ompBuilder->createCanonicalLoop(
          loopLoc,
          [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
            // Register the mapping of MLIR induction variable to LLVM-IR
            // induction variable
            moduleTranslation.mapValue(loopIV, llvmIV);

            builder.restoreIP(ip);
                convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
                                    moduleTranslation);

            return bodyGenStatus.takeError();
          },
          llvmTC, "omp.loop");
  if (!llvmOrError)
    return op.emitError(llvm::toString(llvmOrError.takeError()));

  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  // Resume emission after the generated loop.
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(afterIP);

  // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
  if (Value cli = op.getCli())
    moduleTranslation.mapOmpLoop(cli, llvmCLI);

  return success();
}
3738
3739/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3740/// OpenMPIRBuilder.
3741static LogicalResult
3742applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3743 LLVM::ModuleTranslation &moduleTranslation) {
3744 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3745
3746 Value applyee = op.getApplyee();
3747 assert(applyee && "Loop to apply unrolling on required");
3748
3749 llvm::CanonicalLoopInfo *consBuilderCLI =
3750 moduleTranslation.lookupOMPLoop(applyee);
3751 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3752 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3753
3754 moduleTranslation.invalidateOmpLoop(applyee);
3755 return success();
3756}
3757
3758/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3759/// OpenMPIRBuilder.
3760static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3761 LLVM::ModuleTranslation &moduleTranslation) {
3762 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3763 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3764
3766 SmallVector<llvm::Value *> translatedSizes;
3767
3768 for (Value size : op.getSizes()) {
3769 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3770 assert(translatedSize &&
3771 "sizes clause arguments must already be translated");
3772 translatedSizes.push_back(translatedSize);
3773 }
3774
3775 for (Value applyee : op.getApplyees()) {
3776 llvm::CanonicalLoopInfo *consBuilderCLI =
3777 moduleTranslation.lookupOMPLoop(applyee);
3778 assert(applyee && "Canonical loop must already been translated");
3779 translatedLoops.push_back(consBuilderCLI);
3780 }
3781
3782 auto generatedLoops =
3783 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3784 if (!op.getGeneratees().empty()) {
3785 for (auto [mlirLoop, genLoop] :
3786 zip_equal(op.getGeneratees(), generatedLoops))
3787 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3788 }
3789
3790 // CLIs can only be consumed once
3791 for (Value applyee : op.getApplyees())
3792 moduleTranslation.invalidateOmpLoop(applyee);
3793
3794 return success();
3795}
3796
3797/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
3798/// OpenMPIRBuilder.
3799static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
3800 LLVM::ModuleTranslation &moduleTranslation) {
3801 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3802 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3803
3804 // Select what CLIs are going to be fused
3805 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
3806 for (size_t i = 0; i < op.getApplyees().size(); i++) {
3807 Value applyee = op.getApplyees()[i];
3808 llvm::CanonicalLoopInfo *consBuilderCLI =
3809 moduleTranslation.lookupOMPLoop(applyee);
3810 assert(applyee && "Canonical loop must already been translated");
3811 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
3812 beforeFuse.push_back(consBuilderCLI);
3813 else if (op.getCount().has_value() &&
3814 i >= op.getFirst().value() + op.getCount().value() - 1)
3815 afterFuse.push_back(consBuilderCLI);
3816 else
3817 toFuse.push_back(consBuilderCLI);
3818 }
3819 assert(
3820 (op.getGeneratees().empty() ||
3821 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
3822 "Wrong number of generatees");
3823
3824 // do the fuse
3825 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
3826 if (!op.getGeneratees().empty()) {
3827 size_t i = 0;
3828 for (; i < beforeFuse.size(); i++)
3829 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
3830 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
3831 for (; i < afterFuse.size(); i++)
3832 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
3833 }
3834
3835 // CLIs can only be consumed once
3836 for (Value applyee : op.getApplyees())
3837 moduleTranslation.invalidateOmpLoop(applyee);
3838
3839 return success();
3840}
3841
3842/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3843static llvm::AtomicOrdering
3844convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3845 if (!ao)
3846 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3847
3848 switch (*ao) {
3849 case omp::ClauseMemoryOrderKind::Seq_cst:
3850 return llvm::AtomicOrdering::SequentiallyConsistent;
3851 case omp::ClauseMemoryOrderKind::Acq_rel:
3852 return llvm::AtomicOrdering::AcquireRelease;
3853 case omp::ClauseMemoryOrderKind::Acquire:
3854 return llvm::AtomicOrdering::Acquire;
3855 case omp::ClauseMemoryOrderKind::Release:
3856 return llvm::AtomicOrdering::Release;
3857 case omp::ClauseMemoryOrderKind::Relaxed:
3858 return llvm::AtomicOrdering::Monotonic;
3859 }
3860 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3861}
3862
3863/// Convert omp.atomic.read operation to LLVM IR.
3864static LogicalResult
3865convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3866 LLVM::ModuleTranslation &moduleTranslation) {
3867 auto readOp = cast<omp::AtomicReadOp>(opInst);
3868 if (failed(checkImplementationStatus(opInst)))
3869 return failure();
3870
3871 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3872 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3873 findAllocaInsertPoint(builder, moduleTranslation);
3874
3875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3876
3877 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3878 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3879 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3880
3881 llvm::Type *elementType =
3882 moduleTranslation.convertType(readOp.getElementType());
3883
3884 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3885 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3886 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3887 return success();
3888}
3889
3890/// Converts an omp.atomic.write operation to LLVM IR.
3891static LogicalResult
3892convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3893 LLVM::ModuleTranslation &moduleTranslation) {
3894 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3895 if (failed(checkImplementationStatus(opInst)))
3896 return failure();
3897
3898 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3899 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3900 findAllocaInsertPoint(builder, moduleTranslation);
3901
3902 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3903 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3904 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3905 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3906 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3907 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3908 /*isVolatile=*/false};
3909 builder.restoreIP(
3910 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3911 return success();
3912}
3913
3914/// Converts an LLVM dialect binary operation to the corresponding enum value
3915/// for `atomicrmw` supported binary operation.
3916static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3918 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3919 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3920 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3921 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3922 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3923 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3924 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3925 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3926 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3927 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3928}
3929
3930static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3931 bool &isIgnoreDenormalMode,
3932 bool &isFineGrainedMemory,
3933 bool &isRemoteMemory) {
3934 isIgnoreDenormalMode = false;
3935 isFineGrainedMemory = false;
3936 isRemoteMemory = false;
3937 if (atomicUpdateOp &&
3938 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3939 mlir::omp::AtomicControlAttr atomicControlAttr =
3940 atomicUpdateOp.getAtomicControlAttr();
3941 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3942 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3943 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3944 }
3945}
3946
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// When the update region contains exactly one operation (plus terminator)
/// that is a recognized binary op on the region argument, an `atomicrmw` may
/// be emitted; otherwise the builder falls back to a compare-exchange loop
/// that re-evaluates the region via \p updateFn.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Record whether the expression has the shape `x binop expr` (region
    // argument on the LHS) versus `expr binop x`.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code. The callback translates the whole update region at
  // the current insertion point, mapping the region argument to the loaded
  // atomic value, and returns the value yielded by the region terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Atomic-control flags from the op's atomic_control attribute, if present.
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4031
/// Converts an OpenMP atomic capture construct (atomic read paired with an
/// atomic update or atomic write) into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // atomic.write always captures the old value first ("v = x; x = expr").
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // The capture is "postfix" when the update runs after the read, i.e. the
    // update op is the second operation in the capture region.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than a single update op in the region: fall back to cmpxchg.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Update callback: for atomic.write simply return the value to store,
  // otherwise translate the update region with the region argument bound to
  // the loaded atomic value and return the yielded result.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Atomic-control flags from the update op's attribute (null update op
  // yields all-false flags).
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4132
4133static llvm::omp::Directive convertCancellationConstructType(
4134 omp::ClauseCancellationConstructType directive) {
4135 switch (directive) {
4136 case omp::ClauseCancellationConstructType::Loop:
4137 return llvm::omp::Directive::OMPD_for;
4138 case omp::ClauseCancellationConstructType::Parallel:
4139 return llvm::omp::Directive::OMPD_parallel;
4140 case omp::ClauseCancellationConstructType::Sections:
4141 return llvm::omp::Directive::OMPD_sections;
4142 case omp::ClauseCancellationConstructType::Taskgroup:
4143 return llvm::omp::Directive::OMPD_taskgroup;
4144 }
4145 llvm_unreachable("Unhandled cancellation construct type");
4146}
4147
4148static LogicalResult
4149convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4150 LLVM::ModuleTranslation &moduleTranslation) {
4151 if (failed(checkImplementationStatus(*op.getOperation())))
4152 return failure();
4153
4154 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4155 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4156
4157 llvm::Value *ifCond = nullptr;
4158 if (Value ifVar = op.getIfExpr())
4159 ifCond = moduleTranslation.lookupValue(ifVar);
4160
4161 llvm::omp::Directive cancelledDirective =
4162 convertCancellationConstructType(op.getCancelDirective());
4163
4164 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4165 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4166
4167 if (failed(handleError(afterIP, *op.getOperation())))
4168 return failure();
4169
4170 builder.restoreIP(afterIP.get());
4171
4172 return success();
4173}
4174
4175static LogicalResult
4176convertOmpCancellationPoint(omp::CancellationPointOp op,
4177 llvm::IRBuilderBase &builder,
4178 LLVM::ModuleTranslation &moduleTranslation) {
4179 if (failed(checkImplementationStatus(*op.getOperation())))
4180 return failure();
4181
4182 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4183 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4184
4185 llvm::omp::Directive cancelledDirective =
4186 convertCancellationConstructType(op.getCancelDirective());
4187
4188 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4189 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4190
4191 if (failed(handleError(afterIP, *op.getOperation())))
4192 return failure();
4193
4194 builder.restoreIP(afterIP.get());
4195
4196 return success();
4197}
4198
4199/// Converts an OpenMP Threadprivate operation into LLVM IR using
4200/// OpenMPIRBuilder.
4201static LogicalResult
4202convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4203 LLVM::ModuleTranslation &moduleTranslation) {
4204 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4205 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4206 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4207
4208 if (failed(checkImplementationStatus(opInst)))
4209 return failure();
4210
4211 Value symAddr = threadprivateOp.getSymAddr();
4212 auto *symOp = symAddr.getDefiningOp();
4213
4214 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4215 symOp = asCast.getOperand().getDefiningOp();
4216
4217 if (!isa<LLVM::AddressOfOp>(symOp))
4218 return opInst.emitError("Addressing symbol not found");
4219 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4220
4221 LLVM::GlobalOp global =
4222 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4223 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4224 llvm::Type *type = globalValue->getValueType();
4225 llvm::TypeSize typeSize =
4226 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4227 type);
4228 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4229 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4230 ompLoc, globalValue, size, global.getSymName() + ".cache");
4231 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4232
4233 return success();
4234}
4235
4236static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4237convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4238 switch (deviceClause) {
4239 case mlir::omp::DeclareTargetDeviceType::host:
4240 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4241 break;
4242 case mlir::omp::DeclareTargetDeviceType::nohost:
4243 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4244 break;
4245 case mlir::omp::DeclareTargetDeviceType::any:
4246 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4247 break;
4248 }
4249 llvm_unreachable("unhandled device clause");
4250}
4251
/// Maps an MLIR declare-target capture clause onto the corresponding offload
/// entries global-variable entry kind.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  case mlir::omp::DeclareTargetCaptureClause::none:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
  }
  llvm_unreachable("unhandled capture clause");
}
4267
  Operation *op = value.getDefiningOp();
  // Look through an address space cast, if any, to the underlying defining
  // operation.
  if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
    op = addrCast->getOperand(0).getDefiningOp();
  // If the value takes the address of a symbol, resolve that symbol in the
  // enclosing module and return its defining operation.
  if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    return modOp.lookupSymbol(addressOfOp.getGlobalName());
  }
  // Not an address-of a global (e.g. a block argument or local value).
  return nullptr;
}
4278
/// Builds the suffix appended to a declare-target global's name to form the
/// name of its reference pointer. Private globals additionally embed a
/// per-file unique id (derived from the source file) so that identically
/// named private globals from different files do not collide.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): findInstanceOf may return a null location if the global
    // has no FileLineColLoc — confirm callers guarantee one is present.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    auto vfs = llvm::vfs::getRealFileSystem();
    os << llvm::format(
        "_%x",
        ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
4300
4301static bool isDeclareTargetLink(Value value) {
4302 if (auto declareTargetGlobal =
4303 dyn_cast_if_present<omp::DeclareTargetInterface>(
4304 getGlobalOpFromValue(value)))
4305 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4306 omp::DeclareTargetCaptureClause::link)
4307 return true;
4308 return false;
4309}
4310
4311static bool isDeclareTargetTo(Value value) {
4312 if (auto declareTargetGlobal =
4313 dyn_cast_if_present<omp::DeclareTargetInterface>(
4314 getGlobalOpFromValue(value)))
4315 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4316 omp::DeclareTargetCaptureClause::to ||
4317 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4318 omp::DeclareTargetCaptureClause::enter)
4319 return true;
4320 return false;
4321}
4322
// Returns the reference pointer generated by the lowering of the declare
// target operation in cases where the link clause is used or the to clause is
// used in USM mode. Returns nullptr for all other values.
static llvm::Value *
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (auto gOp =
          dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
    if (auto declareTargetGlobal =
            dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
      // In this case, we must utilise the reference pointer generated by
      // the declare target operation, similar to Clang
      if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
           omp::DeclareTargetCaptureClause::link) ||
          (declareTargetGlobal.getDeclareTargetCaptureClause() ==
               omp::DeclareTargetCaptureClause::to &&
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        llvm::SmallString<64> suffix =
            getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);

        // The global itself may already be the suffixed reference pointer.
        if (gOp.getSymName().contains(suffix))
          return moduleTranslation.getLLVMModule()->getNamedValue(
              gOp.getSymName());

        // Otherwise look up the companion "<name><suffix>" global.
        return moduleTranslation.getLLVMModule()->getNamedValue(
            (gOp.getSymName().str() + suffix.str()).str());
      }
    }
  }
  return nullptr;
}
4355
4356namespace {
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // One entry per mapped variable; nullptr when no user-defined mapper
  // applies to that entry (presumably an omp.declare_mapper op otherwise —
  // TODO confirm against where Mappers is populated).
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo, keeping Mappers in lock-step with the
  /// base-class parallel arrays.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
struct MapInfoData : MapInfosTy {
  // True when the mapped variable is a declare-target global.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True when the entry is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The originating map-info operation for each entry.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The translated LLVM value for each mapped variable.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  /// NOTE(review): this does not append IsAMember or IsAMapping, so those
  /// vectors fall out of step with the others after an append — confirm
  /// whether that is intentional for all callers.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
4395
// Discriminates the OpenMP target-related directive an operation represents;
// None is used for operations that are not target directives.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};
4404
4405static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4406 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4407 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4408 .Case([](omp::TargetEnterDataOp) {
4409 return TargetDirectiveEnumTy::TargetEnterData;
4410 })
4411 .Case([&](omp::TargetExitDataOp) {
4412 return TargetDirectiveEnumTy::TargetExitData;
4413 })
4414 .Case([&](omp::TargetUpdateOp) {
4415 return TargetDirectiveEnumTy::TargetUpdate;
4416 })
4417 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4418 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4419}
4420
4421} // namespace
4422
4423static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4424 DataLayout &dl) {
4425 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4426 arrTy.getElementType()))
4427 return getArrayElementSizeInBits(nestedArrTy, dl);
4428 return dl.getTypeSizeInBits(arrTy.getElementType());
4429}
4430
// This function calculates the size to be offloaded for a specified type, given
// its associated map clause (which can contain bounds information which affects
// the total size), this size is calculated based on the underlying element type
// e.g. given a 1-D array of ints, we will calculate the size from the integer
// type * number of elements in the array. This size can be used in other
// calculations but is ultimately used as an argument to the OpenMP runtimes
// kernel argument structure which is generated through the combinedInfo data
// structures.
// This function is somewhat equivalent to Clang's getExprTypeSize inside of
// CGOpenMPRuntime.cpp.
//
// \p clauseOp is expected to be an omp.map.info op (other ops fall through to
// the plain type-size case); \p basePointer and \p baseType are currently
// unused here but kept for interface parity with callers.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      // For arrays, the byte size must come from the innermost element type,
      // since the bounds already account for the element count.
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No bounds information: fall back to the static size of the type.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
4490
4491// Convert the MLIR map flag set to the runtime map flag set for embedding
4492// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4493// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4494// Certain flags are discarded here such as RefPtee and co.
4495static llvm::omp::OpenMPOffloadMappingFlags
4496convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4497 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4498 return (mlirFlags & flag) == flag;
4499 };
4500 const bool hasExplicitMap =
4501 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4502 omp::ClauseMapFlags::none;
4503
4504 llvm::omp::OpenMPOffloadMappingFlags mapType =
4505 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4506
4507 if (mapTypeToBool(omp::ClauseMapFlags::to))
4508 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4509
4510 if (mapTypeToBool(omp::ClauseMapFlags::from))
4511 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4512
4513 if (mapTypeToBool(omp::ClauseMapFlags::always))
4514 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4515
4516 if (mapTypeToBool(omp::ClauseMapFlags::del))
4517 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4518
4519 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4520 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4521
4522 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4523 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4524
4525 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4526 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4527
4528 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4529 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4530
4531 if (mapTypeToBool(omp::ClauseMapFlags::close))
4532 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4533
4534 if (mapTypeToBool(omp::ClauseMapFlags::present))
4535 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4536
4537 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4538 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4539
4540 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4541 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4542
4543 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4544 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4545 if (!hasExplicitMap)
4546 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4547 }
4548
4549 return mapType;
4550}
4551
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
    ArrayRef<Value> useDevAddrOperands = {},
    ArrayRef<Value> hasDevAddrOperands = {}) {
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands: one entry per explicit map clause. All the parallel
  // arrays in mapData are pushed in lock-step so indices line up.
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    // Prefer varPtrPtr (pointer-to-pointer) when present.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
      // declare target link / to-in-USM: base pointer is the reference ptr.
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else if (isDeclareTargetTo(offloadPtr)) {
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    if (mapOp.getMapperId())
      mapData.Mappers.push_back(
          mapOp, mapOp.getMapperIdAttr()));
    else
      mapData.Mappers.push_back(nullptr);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  // Marks an already-collected mapping entry as a device pointer/address
  // return param; returns false if no matching entry exists yet.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(LLVM::createMappingInformation(
            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.Mappers.push_back(nullptr);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);

  // Process has_device_addr / is_device_ptr operands.
  for (Value mapValue : hasDevAddrOperands) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
    auto mapType = convertClauseMapFlags(mapOp.getMapType());
    auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
    bool isDevicePtr =
        (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
        omp::ClauseMapFlags::none;

    mapData.OriginalValue.push_back(origValue);
    mapData.BasePointers.push_back(origValue);
    mapData.Pointers.push_back(origValue);
    mapData.IsDeclareTarget.push_back(false);
    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
    mapData.MapClause.push_back(mapOp.getOperation());
    if (llvm::to_underlying(mapType & mapTypeAlways)) {
      // Descriptors are mapped with the ALWAYS flag, since they can get
      // rematerialized, so the address of the descriptor for a given object
      // may change from one place to another.
      mapData.Types.push_back(mapType);
      // Technically it's possible for a non-descriptor mapping to have
      // both has-device-addr and ALWAYS, so lookup the mapper in case it
      // exists.
      if (mapOp.getMapperId()) {
        mapData.Mappers.push_back(
            mapOp, mapOp.getMapperIdAttr()));
      } else {
        mapData.Mappers.push_back(nullptr);
      }
    } else {
      // For is_device_ptr we need the map type to propagate so the runtime
      // can materialize the device-side copy of the pointer container.
      mapData.Types.push_back(
          isDevicePtr ? mapType
                      : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
      mapData.Mappers.push_back(nullptr);
    }
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(
        isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
                    : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
    mapData.IsAMapping.push_back(false);
    mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
  }
}
4717
4718static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4719 auto *res = llvm::find(mapData.MapClause, memberOp);
4720 assert(res != mapData.MapClause.end() &&
4721 "MapInfoOp for member not found in MapData, cannot return index");
4722 return std::distance(mapData.MapClause.begin(), res);
4723}
4724
                           omp::MapInfoOp mapInfo) {
  // Orders member indices lexicographically by their per-member index paths
  // (from mapInfo's members_index attribute), then drops children that are
  // fully covered by a mapped parent.
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  llvm::SmallVector<size_t> occludedChildren;
  llvm::sort(
      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
        // Bail early if we are asked to look at the same index. If we do not
        // bail early, we can end up mistakenly adding indices to
        // occludedChildren. This can occur with some types of libc++ hardening.
        if (a == b)
          return false;

        auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
        auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);

        // Compare element-wise until the paths diverge.
        for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
          int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
          int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();

          if (aIndex == bIndex)
            continue;

          if (aIndex < bIndex)
            return true;

          if (aIndex > bIndex)
            return false;
        }

        // Iterated up until the end of the smallest member and
        // they were found to be equal up to that point, so select
        // the member with the lowest index count, so the "parent"
        bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
        if (memberAParent)
          occludedChildren.push_back(b);
        else
          occludedChildren.push_back(a);
        return memberAParent;
      });

  // We remove children from the index list that are overshadowed by
  // a parent, this prevents us retrieving these as the first or last
  // element when the parent is the correct element in these cases.
  for (auto v : occludedChildren)
    indices.erase(std::remove(indices.begin(), indices.end(), v),
                  indices.end());
}
4772
4773static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4774 bool first) {
4775 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4776 // Only 1 member has been mapped, we can return it.
4777 if (indexAttr.size() == 1)
4778 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4779 llvm::SmallVector<size_t> indices(indexAttr.size());
4780 std::iota(indices.begin(), indices.end(), 0);
4781 sortMapIndices(indices, mapInfo);
4782 return llvm::cast<omp::MapInfoOp>(
4783 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4784 .getDefiningOp());
4785}
4786
4787/// This function calculates the array/pointer offset for map data provided
4788/// with bounds operations, e.g. when provided something like the following:
4789///
4790/// Fortran
4791/// map(tofrom: array(2:5, 3:2))
4792///
4793/// We must calculate the initial pointer offset to pass across, this function
4794/// performs this using bounds.
4795///
4796/// TODO/WARNING: This only supports Fortran's column major indexing currently
4797/// as is noted in the note below and comments in the function, we must extend
4798/// this function when we add a C++ frontend.
4799/// NOTE: which while specified in row-major order it currently needs to be
4800/// flipped for Fortran's column order array allocation and access (as
4801/// opposed to C++'s row-major, hence the backwards processing where order is
4802/// important). This is likely important to keep in mind for the future when
4803/// we incorporate a C++ frontend, both frontends will need to agree on the
4804/// ordering of generated bounds operations (one may have to flip them) to
4805/// make the below lowering frontend agnostic. The offload size
4806/// calculation may also have to be adjusted for C++.
4807static std::vector<llvm::Value *>
4809                      llvm::IRBuilderBase &builder, bool isArrayTy,
4810                      OperandRange bounds) {
4811  std::vector<llvm::Value *> idx;
4812  // There's no bounds to calculate an offset from, we can safely
4813  // ignore and return no indices.
4814  if (bounds.empty())
4815    return idx;
4816
4817  // If we have an array type, then we have its type so can treat it as a
4818  // normal GEP instruction where the bounds operations are simply indexes
4819  // into the array. We currently do reverse order of the bounds, which
4820  // I believe leans more towards Fortran's column-major in memory.
4821  if (isArrayTy) {
4822    idx.push_back(builder.getInt64(0));
4823    for (int i = bounds.size() - 1; i >= 0; --i) {
4824      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4825              bounds[i].getDefiningOp())) {
        // Each dimension contributes its lower bound as a GEP index.
4826        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4827      }
4828    }
4829  } else {
4830    // If we do not have an array type, but we have bounds, then we're dealing
4831    // with a pointer that's being treated like an array and we have the
4832    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
4833    // address (pointer pointing to the actual data) so we must calculate the
4834    // offset using a single index which the following loop attempts to
4835    // compute using the standard column-major algorithm e.g for a 3D array:
4836    //
4837    // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
4838    //
4839    // It is of note that it's doing column-major rather than row-major at the
4840    // moment, but having a way for the frontend to indicate which major format
4841    // to use or standardizing/canonicalizing the order of the bounds to compute
4842    // the offset may be useful in the future when there's other frontends with
4843    // different formats.
4845    for (int i = bounds.size() - 1; i >= 0; --i) {
4846      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4847              bounds[i].getDefiningOp())) {
        // Seed the single index with the outermost dimension's lower bound,
        // then fold each remaining dimension in via:
        //   idx = idx * extent + lowerBound
4848        if (i == ((int)bounds.size() - 1))
4849          idx.emplace_back(
4850              moduleTranslation.lookupValue(boundOp.getLowerBound()));
4851        else
4852          idx.back() = builder.CreateAdd(
4853              builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4854                                                boundOp.getExtent())),
4855              moduleTranslation.lookupValue(boundOp.getLowerBound()));
4856      }
4857    }
4858  }
4859
4860  return idx;
4861}
4862
// Converts each IntegerAttr element of the attribute array `values` into its
// raw int64_t value, appending the results onto the output integer vector
// `ints`.
4864  llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4865    return cast<IntegerAttr>(value).getInt();
4866  });
4867}
4868
4869// Gathers members that are overlapping in the parent, excluding members that
4870// themselves overlap, keeping the top-most (closest to parents level) map.
4871static void
4873                     omp::MapInfoOp parentOp) {
4874  // No members mapped, no overlaps.
4875  if (parentOp.getMembers().empty())
4876    return;
4877
4878  // Single member, we can insert and return early.
4879  if (parentOp.getMembers().size() == 1) {
4880    overlapMapDataIdxs.push_back(0);
4881    return;
4882  }
4883
4884  // 1) collect list of top-level overlapping members from MemberOp
4886  ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4887  for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4888    memberByIndex.push_back(
4889        std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4890
4891  // Sort the smallest first (higher up the parent -> member chain), so that
4892  // when we remove members, we remove as much as we can in the initial
4893  // iterations, shortening the number of passes required.
4894  llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4895             [&](auto a, auto b) { return a.second.size() < b.second.size(); });
4896
4897  // Remove elements from the vector if there is a parent element that
4898  // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
4899  // [0,2].. etc.
4901  for (auto v : memberByIndex) {
4902    llvm::SmallVector<int64_t> vArr(v.second.size());
4903    getAsIntegers(v.second, vArr);
    // A member x supersedes v when v's index path is a (non-strict) prefix of
    // x's path. Checking the size first also guards the std::equal below
    // against reading past the end of a shorter path.
    auto superseded =
        std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
          if (v == x)
            return false;
          llvm::SmallVector<int64_t> xArr(x.second.size());
          getAsIntegers(x.second, xArr);
          return xArr.size() >= vArr.size() &&
                 std::equal(vArr.begin(), vArr.end(), xArr.begin());
        });
    // Only record a hit; dereferencing the end iterator when nothing
    // supersedes v would be undefined behaviour.
    if (superseded != memberByIndex.end())
      skipList.push_back(*superseded);
4913  }
4914
4915  // Collect the indices, as we need the base pointer etc. from the MapData
4916  // structure which is primarily accessible via index at the moment.
4917  for (auto v : memberByIndex)
4918    if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4919      overlapMapDataIdxs.push_back(v.first);
4920}
4921
4922// The intent is to verify if the mapped data being passed is a
4923// pointer -> pointee that requires special handling in certain cases,
4924// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4925//
4926// There may be a better way to verify this, but unfortunately with
4927// opaque pointers we lose the ability to easily check if something is
4928// a pointer whilst maintaining access to the underlying type.
4929static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4930 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4931 if (mapOp.getVarPtrPtr())
4932 return true;
4933
4934 // If the map data is declare target with a link clause, then it's represented
4935 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4936 // no relation to pointers.
4937 if (isDeclareTargetLink(mapOp.getVarPtr()))
4938 return true;
4939
4940 return false;
4941}
4942
4943// This creates two insertions into the MapInfosTy data structure for the
4944// "parent" of a set of members, (usually a container e.g.
4945// class/structure/derived type) when subsequent members have also been
4946// explicitly mapped on the same map clause. Certain types, such as Fortran
4947// descriptors are mapped like this as well, however, the members are
4948// implicit as far as a user is concerned, but we must explicitly map them
4949// internally.
4950//
4951// This function also returns the memberOfFlag for this particular parent,
4952// which is utilised in subsequent member mappings (by modifying their map type
4953// with it) to indicate that a member is part of this parent and should be
4954// treated by the runtime as such. Important to achieve the correct mapping.
4955//
4956// This function borrows a lot from Clang's emitCombinedEntry function
4957// inside of CGOpenMPRuntime.cpp
4958static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
4959    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4960    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4961    MapInfoData &mapData, uint64_t mapDataIndex,
4962    TargetDirectiveEnumTy targetDirective) {
4963  assert(!ompBuilder.Config.isTargetDevice() &&
4964         "function only supported for host device codegen");
4965
4966  auto parentClause =
4967      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4968
4969  auto *parentMapper = mapData.Mappers[mapDataIndex];
4970
4971  // Map the first segment of the parent. If a user-defined mapper is attached,
4972  // include the parent's to/from-style bits (and common modifiers) in this
4973  // base entry so the mapper receives correct copy semantics via its 'type'
4974  // parameter. Also keep TARGET_PARAM when required for kernel arguments.
4975  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4976      (targetDirective == TargetDirectiveEnumTy::Target &&
4977       !mapData.IsDeclareTarget[mapDataIndex])
4978          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4979          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4980
4981  if (parentMapper) {
4982    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4983    // Preserve relevant map-type bits from the parent clause. These include
4984    // the copy direction (TO/FROM), as well as commonly used modifiers that
4985    // should be visible to the mapper for correct behaviour.
4986    mapFlags parentFlags = mapData.Types[mapDataIndex];
4987    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4988                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4989                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4990    baseFlag |= (parentFlags & preserve);
4991  }
4992
4993  combinedInfo.Types.emplace_back(baseFlag);
4994  combinedInfo.DevicePointers.emplace_back(
4995      mapData.DevicePointers[mapDataIndex]);
4996  // Only attach the mapper to the base entry when we are mapping the whole
4997  // parent. Combined/segment entries must not carry a mapper; otherwise the
4998  // mapper can be invoked with a partial size, which is undefined behaviour.
4999  combinedInfo.Mappers.emplace_back(
5000      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
5001  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5002      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5003  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
5004
5005  // Calculate size of the parent object being mapped based on the
5006  // addresses at runtime, highAddr - lowAddr = size. This of course
5007  // doesn't factor in allocated data like pointers, hence the further
5008  // processing of members specified by users, or in the case of
5009  // Fortran pointers and allocatables, the mapping of the pointed to
5010  // data by the descriptor (which itself, is a structure containing
5011  // runtime information on the dynamically allocated data).
5012  llvm::Value *lowAddr, *highAddr;
5013  if (!parentClause.getPartialMap()) {
    // Whole object mapped: the runtime extent is [obj, obj + 1).
5014    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5015                                        builder.getPtrTy());
5016    highAddr = builder.CreatePointerCast(
5017        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5018                                   mapData.Pointers[mapDataIndex], 1),
5019        builder.getPtrTy());
5020    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5021  } else {
    // Partial map: the extent spans from the first to one past the last
    // explicitly mapped member.
5022    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5023    int firstMemberIdx = getMapDataMemberIdx(
5024        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
5025    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5026                                        builder.getPtrTy());
5027    int lastMemberIdx = getMapDataMemberIdx(
5028        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
5029    highAddr = builder.CreatePointerCast(
5030        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5031                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5032        builder.getPtrTy());
5033    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
5034  }
5035
5036  llvm::Value *size = builder.CreateIntCast(
5037      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5038      builder.getInt64Ty(),
5039      /*isSigned=*/false);
5040  combinedInfo.Sizes.push_back(size);
5041
  // Derived from the base entry just emitted; this flag ties all subsequent
  // member entries back to this parent.
5042  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5043      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5044
5045  // This creates the initial MEMBER_OF mapping that consists of
5046  // the parent/top level container (same as above effectively, except
5047  // with a fixed initial compile time size and separate maptype which
5048  // indicates the true map type (tofrom etc.). This parent mapping is
5049  // only relevant if the structure in its totality is being mapped,
5050  // otherwise the above suffices.
5051  if (!parentClause.getPartialMap()) {
5052    // TODO: This will need to be expanded to include the whole host of logic
5053    // for the map flags that Clang currently supports (e.g. it should do some
5054    // further case specific flag modifications). For the moment, it handles
5055    // what we support as expected.
5056    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5057    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5058                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5059                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5060    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5061
5062    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // Target update (or close-modified maps): a single entry covering the
      // entire parent with its compile-time size.
5063      combinedInfo.Types.emplace_back(mapFlag);
5064      combinedInfo.DevicePointers.emplace_back(
5065          mapData.DevicePointers[mapDataIndex]);
5066      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5067          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5068      combinedInfo.BasePointers.emplace_back(
5069          mapData.BasePointers[mapDataIndex]);
5070      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5071      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5072      combinedInfo.Mappers.emplace_back(nullptr);
5073    } else {
5074      llvm::SmallVector<size_t> overlapIdxs;
5075      // Find all of the members that "overlap", i.e. occlude other members that
5076      // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
5077      // [0,2], but not [1,0].
5078      getOverlappedMembers(overlapIdxs, parentClause);
5079      // We need to make sure the overlapped members are sorted in order of
5080      // lowest address to highest address.
5081      sortMapIndices(overlapIdxs, parentClause);
5082
5083      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5084                                          builder.getPtrTy());
5085      highAddr = builder.CreatePointerCast(
5086          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5087                                     mapData.Pointers[mapDataIndex], 1),
5088          builder.getPtrTy());
5089
5090      // TODO: We may want to skip arrays/array sections in this as Clang does.
5091      // It appears to be an optimisation rather than a necessity though,
5092      // but this requires further investigation. However, we would have to make
5093      // sure to not exclude maps with bounds that ARE pointers, as these are
5094      // processed as separate components, i.e. pointer + data.
      // Emit one entry per "gap" segment: from the current lowAddr up to the
      // start of the next overlapped member, then advance lowAddr past that
      // member so its own mapping covers it.
5095      for (auto v : overlapIdxs) {
5096        auto mapDataOverlapIdx = getMapDataMemberIdx(
5097            mapData,
5098            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5099        combinedInfo.Types.emplace_back(mapFlag);
5100        combinedInfo.DevicePointers.emplace_back(
5101            mapData.DevicePointers[mapDataOverlapIdx]);
5102        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5103            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5104        combinedInfo.BasePointers.emplace_back(
5105            mapData.BasePointers[mapDataIndex]);
5106        combinedInfo.Mappers.emplace_back(nullptr);
5107        combinedInfo.Pointers.emplace_back(lowAddr);
5108        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5109            builder.CreatePtrDiff(builder.getInt8Ty(),
5110                                  mapData.OriginalValue[mapDataOverlapIdx],
5111                                  lowAddr),
5112            builder.getInt64Ty(), /*isSigned=*/true));
5113        lowAddr = builder.CreateConstGEP1_32(
5114            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
5115                mapData.MapClause[mapDataOverlapIdx]))
5116                ? builder.getPtrTy()
5117                : mapData.BaseType[mapDataOverlapIdx],
5118            mapData.BasePointers[mapDataOverlapIdx], 1);
5119      }
5120
      // Final segment: from past the last overlapped member to the end of
      // the parent object.
5121      combinedInfo.Types.emplace_back(mapFlag);
5122      combinedInfo.DevicePointers.emplace_back(
5123          mapData.DevicePointers[mapDataIndex]);
5124      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5125          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5126      combinedInfo.BasePointers.emplace_back(
5127          mapData.BasePointers[mapDataIndex]);
5128      combinedInfo.Mappers.emplace_back(nullptr);
5129      combinedInfo.Pointers.emplace_back(lowAddr);
5130      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5131          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5132          builder.getInt64Ty(), true));
5133    }
5134  }
5135  return memberOfFlag;
5136}
5137
5138// This function is intended to add explicit mappings of members
// mapped on the same clause as their parent container (class/structure/
// derived type), binding each member entry to the parent via the supplied
// MEMBER_OF flag.
5140    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5141    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5142    MapInfoData &mapData, uint64_t mapDataIndex,
5143    llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5144    TargetDirectiveEnumTy targetDirective) {
5145  assert(!ompBuilder.Config.isTargetDevice() &&
5146         "function only supported for host device codegen");
5147
5148  auto parentClause =
5149      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5150
5151  for (auto mappedMembers : parentClause.getMembers()) {
5152    auto memberClause =
5153        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5154    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5155
5156    assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5157
5158    // If we're currently mapping a pointer to a block of data, we must
5159    // initially map the pointer, and then attach/bind the data with a
5160    // subsequent map to the pointer. This segment of code generates the
5161    // pointer mapping, which can in certain cases be optimised out as Clang
5162    // currently does in its lowering. However, for the moment we do not do so,
5163    // in part as we currently have substantially less information on the data
5164    // being mapped at this stage.
5165    if (checkIfPointerMap(memberClause)) {
5166      auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5167      mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5168      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5169      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5170      combinedInfo.Types.emplace_back(mapFlag);
5171      combinedInfo.DevicePointers.emplace_back(
5172          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5173      combinedInfo.Mappers.emplace_back(nullptr);
5174      combinedInfo.Names.emplace_back(
5175          LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5176      combinedInfo.BasePointers.emplace_back(
5177          mapData.BasePointers[mapDataIndex]);
5178      combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
      // The pointer entry itself is pointer-sized regardless of the pointee.
5179      combinedInfo.Sizes.emplace_back(builder.getInt64(
5180          moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5181    }
5182
5183    // Same MemberOfFlag to indicate its link with parent and other members
5184    // of.
5185    auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5186    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5187    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5188    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5189    bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5190                                                ? parentClause.getVarPtr()
5191                                                : parentClause.getVarPtrPtr());
    // Pointer members become PTR_AND_OBJ mappings, except for declare-target
    // 'to' data on target update/data directives.
5192    if (checkIfPointerMap(memberClause) &&
5193        (!isDeclTargetTo ||
5194         (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5195          targetDirective != TargetDirectiveEnumTy::TargetData))) {
5196      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5197    }
5198
5199    combinedInfo.Types.emplace_back(mapFlag);
5200    combinedInfo.DevicePointers.emplace_back(
5201        mapData.DevicePointers[memberDataIdx]);
5202    combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5203    combinedInfo.Names.emplace_back(
5204        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    // For pointer members the member itself is the base pointer of the
    // entry; otherwise the parent acts as the base.
5205    uint64_t basePointerIndex =
5206        checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5207    combinedInfo.BasePointers.emplace_back(
5208        mapData.BasePointers[basePointerIndex]);
5209    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5210
    // A null pointee is mapped with size zero rather than its declared size.
5211    llvm::Value *size = mapData.Sizes[memberDataIdx];
5212    if (checkIfPointerMap(memberClause)) {
5213      size = builder.CreateSelect(
5214          builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5215          builder.getInt64(0), size);
5216    }
5217
5218    combinedInfo.Sizes.emplace_back(size);
5219  }
5220}
5221
5222static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5223 MapInfosTy &combinedInfo,
5224 TargetDirectiveEnumTy targetDirective,
5225 int mapDataParentIdx = -1) {
5226 // Declare Target Mappings are excluded from being marked as
5227 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5228 // marked with OMP_MAP_PTR_AND_OBJ instead.
5229 auto mapFlag = mapData.Types[mapDataIdx];
5230 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5231
5232 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5233 if (isPtrTy)
5234 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5235
5236 if (targetDirective == TargetDirectiveEnumTy::Target &&
5237 !mapData.IsDeclareTarget[mapDataIdx])
5238 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5239
5240 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5241 !isPtrTy)
5242 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5243
5244 // if we're provided a mapDataParentIdx, then the data being mapped is
5245 // part of a larger object (in a parent <-> member mapping) and in this
5246 // case our BasePointer should be the parent.
5247 if (mapDataParentIdx >= 0)
5248 combinedInfo.BasePointers.emplace_back(
5249 mapData.BasePointers[mapDataParentIdx]);
5250 else
5251 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5252
5253 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5254 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5255 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5256 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5257 combinedInfo.Types.emplace_back(mapFlag);
5258 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5259}
5260
// Processes a map clause that has explicitly mapped members: either lowers
// it as a single individual map (single-member partial maps) or emits the
// parent mapping followed by all of its member mappings.
5262                                    llvm::IRBuilderBase &builder,
5263                                    llvm::OpenMPIRBuilder &ompBuilder,
5264                                    DataLayout &dl, MapInfosTy &combinedInfo,
5265                                    MapInfoData &mapData, uint64_t mapDataIndex,
5266                                    TargetDirectiveEnumTy targetDirective) {
5267  assert(!ompBuilder.Config.isTargetDevice() &&
5268         "function only supported for host device codegen");
5269
5270  auto parentClause =
5271      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5272
5273  // If we have a partial map (no parent referenced in the map clauses of the
5274  // directive, only members) and only a single member, we do not need to bind
5275  // the map of the member to the parent, we can pass the member separately.
5276  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5277    auto memberClause = llvm::cast<omp::MapInfoOp>(
5278        parentClause.getMembers()[0].getDefiningOp());
5279    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5280    // Note: Clang treats arrays with explicit bounds that fall into this
5281    // category as a parent with map case, however, it seems this isn't a
5282    // requirement, and processing them as an individual map is fine. So,
5283    // we will handle them as individual maps for the moment, as it's
5284    // difficult for us to check this as we always require bounds to be
5285    // specified currently and it's also marginally more optimal (single
5286    // map rather than two). The difference may come from the fact that
5287    // Clang maps array without bounds as pointers (which we do not
5288    // currently do), whereas we treat them as arrays in all cases
5289    // currently.
5290    processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
5291                         mapDataIndex);
5292    return;
5293  }
5294
  // Emit the parent (obtaining its MEMBER_OF flag), then every member entry
  // tagged with that flag.
5295  llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5296      mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
5297                           combinedInfo, mapData, mapDataIndex,
5298                           targetDirective);
5299  processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
5300                              combinedInfo, mapData, mapDataIndex,
5301                              memberOfParentFlag, targetDirective);
5302}
5303
5304// This is a variation on Clang's GenerateOpenMPCapturedVars, which
5305// generates different operation (e.g. load/store) combinations for
5306// arguments to the kernel, based on map capture kinds which are then
5307// utilised in the combinedInfo in place of the original Map value.
5308static void
5309createAlteredByCaptureMap(MapInfoData &mapData,
5310                          LLVM::ModuleTranslation &moduleTranslation,
5311                          llvm::IRBuilderBase &builder) {
5312  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5313         "function only supported for host device codegen");
5314  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5315    // if it's declare target, skip it, it's handled separately.
5316    if (!mapData.IsDeclareTarget[i]) {
5317      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5318      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5319      bool isPtrTy = checkIfPointerMap(mapOp);
5320
5321      // Currently handles array sectioning lowerbound case, but more
5322      // logic may be required in the future. Clang invokes EmitLValue,
5323      // which has specialised logic for special Clang types such as user
5324      // defines, so it is possible we will have to extend this for
5325      // structures or other complex types. As the general idea is that this
5326      // function mimics some of the logic from Clang that we require for
5327      // kernel argument passing from host -> device.
5328      switch (captureKind) {
5329      case omp::VariableCaptureKind::ByRef: {
5330        llvm::Value *newV = mapData.Pointers[i];
5331        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
5332            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5333            mapOp.getBounds());
        // For pointer-backed data, address the pointee before applying any
        // bounds-derived offset.
5334        if (isPtrTy)
5335          newV = builder.CreateLoad(builder.getPtrTy(), newV);
5336
5337        if (!offsetIdx.empty())
5338          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5339                                           "array_offset");
5340        mapData.Pointers[i] = newV;
5341      } break;
5342      case omp::VariableCaptureKind::ByCopy: {
5343        llvm::Type *type = mapData.BaseType[i];
5344        llvm::Value *newV;
5345        if (mapData.Pointers[i]->getType()->isPointerTy())
5346          newV = builder.CreateLoad(type, mapData.Pointers[i]);
5347        else
5348          newV = mapData.Pointers[i];
5349
5350        if (!isPtrTy) {
          // Spill the value through a temporary alloca placed at the
          // function's alloca insertion point, preserving the current debug
          // location and insertion point so surrounding codegen is
          // unaffected.
5351          auto curInsert = builder.saveIP();
5352          llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5353          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
5354          auto *memTempAlloc =
5355              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
5356          builder.SetCurrentDebugLocation(DbgLoc);
5357          builder.restoreIP(curInsert);
5358
5359          builder.CreateStore(newV, memTempAlloc);
5360          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5361        }
5362
5363        mapData.Pointers[i] = newV;
5364        mapData.BasePointers[i] = newV;
5365      } break;
5366      case omp::VariableCaptureKind::This:
5367      case omp::VariableCaptureKind::VLAType:
        // Unsupported capture kinds are diagnosed rather than silently
        // mis-lowered.
5368        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
5369        break;
5370      }
5371    }
5372  }
5373}
5374
5375// Generate all map related information and fill the combinedInfo.
5376static void genMapInfos(llvm::IRBuilderBase &builder,
5377 LLVM::ModuleTranslation &moduleTranslation,
5378 DataLayout &dl, MapInfosTy &combinedInfo,
5379 MapInfoData &mapData,
5380 TargetDirectiveEnumTy targetDirective) {
5381 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5382 "function only supported for host device codegen");
5383 // We wish to modify some of the methods in which arguments are
5384 // passed based on their capture type by the target region, this can
5385 // involve generating new loads and stores, which changes the
5386 // MLIR value to LLVM value mapping, however, we only wish to do this
5387 // locally for the current function/target and also avoid altering
5388 // ModuleTranslation, so we remap the base pointer or pointer stored
5389 // in the map infos corresponding MapInfoData, which is later accessed
5390 // by genMapInfos and createTarget to help generate the kernel and
5391 // kernel arg structure. It primarily becomes relevant in cases like
5392 // bycopy, or byref range'd arrays. In the default case, we simply
5393 // pass thee pointer byref as both basePointer and pointer.
5394 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5395
5396 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5397
5398 // We operate under the assumption that all vectors that are
5399 // required in MapInfoData are of equal lengths (either filled with
5400 // default constructed data or appropiate information) so we can
5401 // utilise the size from any component of MapInfoData, if we can't
5402 // something is missing from the initial MapInfoData construction.
5403 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5404 // NOTE/TODO: We currently do not support arbitrary depth record
5405 // type mapping.
5406 if (mapData.IsAMember[i])
5407 continue;
5408
5409 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5410 if (!mapInfoOp.getMembers().empty()) {
5411 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5412 combinedInfo, mapData, i, targetDirective);
5413 continue;
5414 }
5415
5416 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5417 }
5418}
5419
5420static llvm::Expected<llvm::Function *>
5421emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
5422 LLVM::ModuleTranslation &moduleTranslation,
5423 llvm::StringRef mapperFuncName,
5424 TargetDirectiveEnumTy targetDirective);
5425
5426static llvm::Expected<llvm::Function *>
5427getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5428 LLVM::ModuleTranslation &moduleTranslation,
5429 TargetDirectiveEnumTy targetDirective) {
5430 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5431 "function only supported for host device codegen");
5432 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5433 std::string mapperFuncName =
5434 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5435 {"omp_mapper", declMapperOp.getSymName()});
5436
5437 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5438 return lookupFunc;
5439
5440 // Recursive types can cause re-entrant mapper emission. The mapper function
5441 // is created by OpenMPIRBuilder before the callbacks run, so it may already
5442 // exist in the LLVM module even though it is not yet registered in the
5443 // ModuleTranslation mapping table. Reuse and register it to break the
5444 // recursion.
5445 if (llvm::Function *existingFunc =
5446 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
5447 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
5448 return existingFunc;
5449 }
5450
5451 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5452 mapperFuncName, targetDirective);
5453}
5454
/// Translates an `omp.declare_mapper` operation into an LLVM user-defined
/// mapper function named \p mapperFuncName via
/// OpenMPIRBuilder::emitUserDefinedMapper. The mapper's region is converted
/// inside the genMapInfoCB callback, with the mapper's symbolic variable bound
/// to the pointer PHI supplied by the IR builder. Host-only codegen.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's symbolic variable to the pointer PHI provided by the
    // builder, then convert the declare-mapper region in place. Order matters:
    // the value and block mappings must exist before convertBlock runs.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    // Collect and lower the map clauses of the region into combinedInfo.
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Nested user-defined mappers (member maps with their own mapper) are
  // resolved lazily; this may recurse back into this function for recursive
  // types, which getOrCreateUserDefinedMapperFunc handles.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Re-entrant emission (recursive types) may have registered the function
  // already; only register it once, and sanity check the existing mapping.
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
5516
5517static LogicalResult
5518convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
5519 LLVM::ModuleTranslation &moduleTranslation) {
5520 llvm::Value *ifCond = nullptr;
5521 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5522 SmallVector<Value> mapVars;
5523 SmallVector<Value> useDevicePtrVars;
5524 SmallVector<Value> useDeviceAddrVars;
5525 llvm::omp::RuntimeFunction RTLFn;
5526 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
5527 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5528
5529 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5530 llvm::OpenMPIRBuilder::TargetDataInfo info(
5531 /*RequiresDevicePointerInfo=*/true,
5532 /*SeparateBeginEndCalls=*/true);
5533 assert(!ompBuilder->Config.isTargetDevice() &&
5534 "target data/enter/exit/update are host ops");
5535 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5536
5537 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
5538 llvm::Value *v = moduleTranslation.lookupValue(dev);
5539 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
5540 };
5541
5542 LogicalResult result =
5544 .Case([&](omp::TargetDataOp dataOp) {
5545 if (failed(checkImplementationStatus(*dataOp)))
5546 return failure();
5547
5548 if (auto ifVar = dataOp.getIfExpr())
5549 ifCond = moduleTranslation.lookupValue(ifVar);
5550
5551 if (mlir::Value devId = dataOp.getDevice())
5552 deviceID = getDeviceID(devId);
5553
5554 mapVars = dataOp.getMapVars();
5555 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5556 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5557 return success();
5558 })
5559 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5560 if (failed(checkImplementationStatus(*enterDataOp)))
5561 return failure();
5562
5563 if (auto ifVar = enterDataOp.getIfExpr())
5564 ifCond = moduleTranslation.lookupValue(ifVar);
5565
5566 if (mlir::Value devId = enterDataOp.getDevice())
5567 deviceID = getDeviceID(devId);
5568
5569 RTLFn =
5570 enterDataOp.getNowait()
5571 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5572 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5573 mapVars = enterDataOp.getMapVars();
5574 info.HasNoWait = enterDataOp.getNowait();
5575 return success();
5576 })
5577 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5578 if (failed(checkImplementationStatus(*exitDataOp)))
5579 return failure();
5580
5581 if (auto ifVar = exitDataOp.getIfExpr())
5582 ifCond = moduleTranslation.lookupValue(ifVar);
5583
5584 if (mlir::Value devId = exitDataOp.getDevice())
5585 deviceID = getDeviceID(devId);
5586
5587 RTLFn = exitDataOp.getNowait()
5588 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5589 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5590 mapVars = exitDataOp.getMapVars();
5591 info.HasNoWait = exitDataOp.getNowait();
5592 return success();
5593 })
5594 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5595 if (failed(checkImplementationStatus(*updateDataOp)))
5596 return failure();
5597
5598 if (auto ifVar = updateDataOp.getIfExpr())
5599 ifCond = moduleTranslation.lookupValue(ifVar);
5600
5601 if (mlir::Value devId = updateDataOp.getDevice())
5602 deviceID = getDeviceID(devId);
5603
5604 RTLFn =
5605 updateDataOp.getNowait()
5606 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5607 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5608 mapVars = updateDataOp.getMapVars();
5609 info.HasNoWait = updateDataOp.getNowait();
5610 return success();
5611 })
5612 .DefaultUnreachable("unexpected operation");
5613
5614 if (failed(result))
5615 return failure();
5616 // Pretend we have IF(false) if we're not doing offload.
5617 if (!isOffloadEntry)
5618 ifCond = builder.getFalse();
5619
5620 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5621 MapInfoData mapData;
5622 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
5623 builder, useDevicePtrVars, useDeviceAddrVars);
5624
5625 // Fill up the arrays with all the mapped variables.
5626 MapInfosTy combinedInfo;
5627 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5628 builder.restoreIP(codeGenIP);
5629 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5630 targetDirective);
5631 return combinedInfo;
5632 };
5633
5634 // Define a lambda to apply mappings between use_device_addr and
5635 // use_device_ptr base pointers, and their associated block arguments.
5636 auto mapUseDevice =
5637 [&moduleTranslation](
5638 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5640 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
5641 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
5642 for (auto [arg, useDevVar] :
5643 llvm::zip_equal(blockArgs, useDeviceVars)) {
5644
5645 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5646 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5647 : mapInfoOp.getVarPtr();
5648 };
5649
5650 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5651 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5652 mapInfoData.MapClause, mapInfoData.DevicePointers,
5653 mapInfoData.BasePointers)) {
5654 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5655 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5656 devicePointer != type)
5657 continue;
5658
5659 if (llvm::Value *devPtrInfoMap =
5660 mapper ? mapper(basePointer) : basePointer) {
5661 moduleTranslation.mapValue(arg, devPtrInfoMap);
5662 break;
5663 }
5664 }
5665 }
5666 };
5667
5668 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5669 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5670 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5671 // We must always restoreIP regardless of doing anything the caller
5672 // does not restore it, leading to incorrect (no) branch generation.
5673 builder.restoreIP(codeGenIP);
5674 assert(isa<omp::TargetDataOp>(op) &&
5675 "BodyGen requested for non TargetDataOp");
5676 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5677 Region &region = cast<omp::TargetDataOp>(op).getRegion();
5678 switch (bodyGenType) {
5679 case BodyGenTy::Priv:
5680 // Check if any device ptr/addr info is available
5681 if (!info.DevicePtrInfoMap.empty()) {
5682 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5683 blockArgIface.getUseDeviceAddrBlockArgs(),
5684 useDeviceAddrVars, mapData,
5685 [&](llvm::Value *basePointer) -> llvm::Value * {
5686 if (!info.DevicePtrInfoMap[basePointer].second)
5687 return nullptr;
5688 return builder.CreateLoad(
5689 builder.getPtrTy(),
5690 info.DevicePtrInfoMap[basePointer].second);
5691 });
5692 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5693 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5694 mapData, [&](llvm::Value *basePointer) {
5695 return info.DevicePtrInfoMap[basePointer].second;
5696 });
5697
5698 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5699 moduleTranslation)))
5700 return llvm::make_error<PreviouslyReportedError>();
5701 }
5702 break;
5703 case BodyGenTy::DupNoPriv:
5704 if (info.DevicePtrInfoMap.empty()) {
5705 // For host device we still need to do the mapping for codegen,
5706 // otherwise it may try to lookup a missing value.
5707 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5708 blockArgIface.getUseDeviceAddrBlockArgs(),
5709 useDeviceAddrVars, mapData);
5710 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5711 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5712 mapData);
5713 }
5714 break;
5715 case BodyGenTy::NoPriv:
5716 // If device info is available then region has already been generated
5717 if (info.DevicePtrInfoMap.empty()) {
5718 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5719 moduleTranslation)))
5720 return llvm::make_error<PreviouslyReportedError>();
5721 }
5722 break;
5723 }
5724 return builder.saveIP();
5725 };
5726
5727 auto customMapperCB =
5728 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5729 if (!combinedInfo.Mappers[i])
5730 return nullptr;
5731 info.HasMapper = true;
5732 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5733 moduleTranslation, targetDirective);
5734 };
5735
5736 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5737 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5738 findAllocaInsertPoint(builder, moduleTranslation);
5739 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5740 if (isa<omp::TargetDataOp>(op))
5741 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5742 deviceID, ifCond, info, genMapInfoCB,
5743 customMapperCB,
5744 /*MapperFunc=*/nullptr, bodyGenCB,
5745 /*DeviceAddrCB=*/nullptr);
5746 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5747 deviceID, ifCond, info, genMapInfoCB,
5748 customMapperCB, &RTLFn);
5749 }();
5750
5751 if (failed(handleError(afterIP, *op)))
5752 return failure();
5753
5754 builder.restoreIP(*afterIP);
5755 return success();
5756}
5757
5758static LogicalResult
5759convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
5760 LLVM::ModuleTranslation &moduleTranslation) {
5761 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5762 auto distributeOp = cast<omp::DistributeOp>(opInst);
5763 if (failed(checkImplementationStatus(opInst)))
5764 return failure();
5765
5766 /// Process teams op reduction in distribute if the reduction is contained in
5767 /// the distribute op.
5768 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
5769 bool doDistributeReduction =
5770 teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;
5771
5772 DenseMap<Value, llvm::Value *> reductionVariableMap;
5773 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
5775 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
5776 llvm::ArrayRef<bool> isByRef;
5777
5778 if (doDistributeReduction) {
5779 isByRef = getIsByRef(teamsOp.getReductionByref());
5780 assert(isByRef.size() == teamsOp.getNumReductionVars());
5781
5782 collectReductionDecls(teamsOp, reductionDecls);
5783 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5784 findAllocaInsertPoint(builder, moduleTranslation);
5785
5786 MutableArrayRef<BlockArgument> reductionArgs =
5787 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
5788 .getReductionBlockArgs();
5789
5791 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
5792 reductionDecls, privateReductionVariables, reductionVariableMap,
5793 isByRef)))
5794 return failure();
5795 }
5796
5797 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5798 auto bodyGenCB = [&](InsertPointTy allocaIP,
5799 InsertPointTy codeGenIP) -> llvm::Error {
5800 // Save the alloca insertion point on ModuleTranslation stack for use in
5801 // nested regions.
5803 moduleTranslation, allocaIP);
5804
5805 // DistributeOp has only one region associated with it.
5806 builder.restoreIP(codeGenIP);
5807 PrivateVarsInfo privVarsInfo(distributeOp);
5808
5810 allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
5811 if (handleError(afterAllocas, opInst).failed())
5812 return llvm::make_error<PreviouslyReportedError>();
5813
5814 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
5815 opInst)
5816 .failed())
5817 return llvm::make_error<PreviouslyReportedError>();
5818
5819 if (failed(copyFirstPrivateVars(
5820 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
5821 privVarsInfo.llvmVars, privVarsInfo.privatizers,
5822 distributeOp.getPrivateNeedsBarrier())))
5823 return llvm::make_error<PreviouslyReportedError>();
5824
5825 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5826 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5828 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
5829 builder, moduleTranslation);
5830 if (!regionBlock)
5831 return regionBlock.takeError();
5832 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
5833
5834 // Skip applying a workshare loop below when translating 'distribute
5835 // parallel do' (it's been already handled by this point while translating
5836 // the nested omp.wsloop).
5837 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
5838 // TODO: Add support for clauses which are valid for DISTRIBUTE
5839 // constructs. Static schedule is the default.
5840 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
5841 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
5842 : omp::ClauseScheduleKind::Static;
5843 // dist_schedule clauses are ordered - otherise this should be false
5844 bool isOrdered = hasDistSchedule;
5845 std::optional<omp::ScheduleModifier> scheduleMod;
5846 bool isSimd = false;
5847 llvm::omp::WorksharingLoopType workshareLoopType =
5848 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
5849 bool loopNeedsBarrier = false;
5850 llvm::Value *chunk = moduleTranslation.lookupValue(
5851 distributeOp.getDistScheduleChunkSize());
5852 llvm::CanonicalLoopInfo *loopInfo =
5853 findCurrentLoopInfo(moduleTranslation);
5854 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
5855 ompBuilder->applyWorkshareLoop(
5856 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
5857 convertToScheduleKind(schedule), chunk, isSimd,
5858 scheduleMod == omp::ScheduleModifier::monotonic,
5859 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
5860 workshareLoopType, false, hasDistSchedule, chunk);
5861
5862 if (!wsloopIP)
5863 return wsloopIP.takeError();
5864 }
5865 if (failed(cleanupPrivateVars(builder, moduleTranslation,
5866 distributeOp.getLoc(), privVarsInfo.llvmVars,
5867 privVarsInfo.privatizers)))
5868 return llvm::make_error<PreviouslyReportedError>();
5869
5870 return llvm::Error::success();
5871 };
5872
5873 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5874 findAllocaInsertPoint(builder, moduleTranslation);
5875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5876 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5877 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
5878
5879 if (failed(handleError(afterIP, opInst)))
5880 return failure();
5881
5882 builder.restoreIP(*afterIP);
5883
5884 if (doDistributeReduction) {
5885 // Process the reductions if required.
5887 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
5888 privateReductionVariables, isByRef,
5889 /*isNoWait*/ false, /*isTeamsReduction*/ true);
5890 }
5891 return success();
5892}
5893
5894/// Lowers the FlagsAttr which is applied to the module on the device
5895/// pass when offloading, this attribute contains OpenMP RTL globals that can
5896/// be passed as flags to the frontend, otherwise they are set to default
5897static LogicalResult
5898convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5899 LLVM::ModuleTranslation &moduleTranslation) {
5900 if (!cast<mlir::ModuleOp>(op))
5901 return failure();
5902
5903 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5904
5905 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5906 attribute.getOpenmpDeviceVersion());
5907
5908 if (attribute.getNoGpuLib())
5909 return success();
5910
5911 ompBuilder->createGlobalFlag(
5912 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5913 "__omp_rtl_debug_kind");
5914 ompBuilder->createGlobalFlag(
5915 attribute
5916 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5917 ,
5918 "__omp_rtl_assume_teams_oversubscription");
5919 ompBuilder->createGlobalFlag(
5920 attribute
5921 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5922 ,
5923 "__omp_rtl_assume_threads_oversubscription");
5924 ompBuilder->createGlobalFlag(
5925 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5926 "__omp_rtl_assume_no_thread_state");
5927 ompBuilder->createGlobalFlag(
5928 attribute
5929 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5930 ,
5931 "__omp_rtl_assume_no_nested_parallelism");
5932 return success();
5933}
5934
5935static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5936 omp::TargetOp targetOp,
5937 llvm::StringRef parentName = "") {
5938 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5939
5940 assert(fileLoc && "No file found from location");
5941 StringRef fileName = fileLoc.getFilename().getValue();
5942
5943 llvm::sys::fs::UniqueID id;
5944 uint64_t line = fileLoc.getLine();
5945 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5946 size_t fileHash = llvm::hash_value(fileName.str());
5947 size_t deviceId = 0xdeadf17e;
5948 targetInfo =
5949 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5950 } else {
5951 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5952 id.getFile(), line);
5953 }
5954}
5955
5956static void
5957handleDeclareTargetMapVar(MapInfoData &mapData,
5958 LLVM::ModuleTranslation &moduleTranslation,
5959 llvm::IRBuilderBase &builder, llvm::Function *func) {
5960 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5961 "function only supported for target device codegen");
5962 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5963 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5964 // In the case of declare target mapped variables, the basePointer is
5965 // the reference pointer generated by the convertDeclareTargetAttr
5966 // method. Whereas the kernelValue is the original variable, so for
5967 // the device we must replace all uses of this original global variable
5968 // (stored in kernelValue) with the reference pointer (stored in
5969 // basePointer for declare target mapped variables), as for device the
5970 // data is mapped into this reference pointer and should be loaded
5971 // from it, the original variable is discarded. On host both exist and
5972 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
5973 // function to link the two variables in the runtime and then both the
5974 // reference pointer and the pointer are assigned in the kernel argument
5975 // structure for the host.
5976 if (mapData.IsDeclareTarget[i]) {
5977 // If the original map value is a constant, then we have to make sure all
5978 // of it's uses within the current kernel/function that we are going to
5979 // rewrite are converted to instructions, as we will be altering the old
5980 // use (OriginalValue) from a constant to an instruction, which will be
5981 // illegal and ICE the compiler if the user is a constant expression of
5982 // some kind e.g. a constant GEP.
5983 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
5984 convertUsersOfConstantsToInstructions(constant, func, false);
5985
5986 // The users iterator will get invalidated if we modify an element,
5987 // so we populate this vector of uses to alter each user on an
5988 // individual basis to emit its own load (rather than one load for
5989 // all).
5991 for (llvm::User *user : mapData.OriginalValue[i]->users())
5992 userVec.push_back(user);
5993
5994 for (llvm::User *user : userVec) {
5995 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
5996 if (insn->getFunction() == func) {
5997 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5998 llvm::Value *substitute = mapData.BasePointers[i];
5999 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
6000 : mapOp.getVarPtr())) {
6001 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6002 substitute = builder.CreateLoad(
6003 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6004 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6005 }
6006 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6007 }
6008 }
6009 }
6010 }
6011 }
6012}
6013
// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generate different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // Allocas must be emitted at the function's alloca insertion point.
  builder.restoreIP(allocaIP);

  // Default to ByRef when no matching map clause is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current (alloca) insertion
  // point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // Targets with a distinct alloca address space (e.g. AMDGPU) need the
  // pointer cast back into the default program address space.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  builder.CreateStore(&arg, v);

  // The access itself is emitted at the kernel's code-gen insertion point.
  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy uses the stored argument's memory directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //                 |         |
    //                 alignment of %input address |
    //                 |
    //                 alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
6140
6141/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6142/// operation and populate output variables with their corresponding host value
6143/// (i.e. operand evaluated outside of the target region), based on their uses
6144/// inside of the target region.
6145///
6146/// Loop bounds and steps are only optionally populated, if output vectors are
6147/// provided.
6148static void
6149extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
6150 Value &numTeamsLower, Value &numTeamsUpper,
6151 Value &threadLimit,
6152 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
6153 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
6154 llvm::SmallVectorImpl<Value> *steps = nullptr) {
6155 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6156 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6157 blockArgIface.getHostEvalBlockArgs())) {
6158 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6159
6160 for (Operation *user : blockArg.getUsers()) {
6162 .Case([&](omp::TeamsOp teamsOp) {
6163 if (teamsOp.getNumTeamsLower() == blockArg)
6164 numTeamsLower = hostEvalVar;
6165 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6166 blockArg))
6167 numTeamsUpper = hostEvalVar;
6168 else if (!teamsOp.getThreadLimitVars().empty() &&
6169 teamsOp.getThreadLimit(0) == blockArg)
6170 threadLimit = hostEvalVar;
6171 else
6172 llvm_unreachable("unsupported host_eval use");
6173 })
6174 .Case([&](omp::ParallelOp parallelOp) {
6175 if (!parallelOp.getNumThreadsVars().empty() &&
6176 parallelOp.getNumThreads(0) == blockArg)
6177 numThreads = hostEvalVar;
6178 else
6179 llvm_unreachable("unsupported host_eval use");
6180 })
6181 .Case([&](omp::LoopNestOp loopOp) {
6182 auto processBounds =
6183 [&](OperandRange opBounds,
6184 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
6185 bool found = false;
6186 for (auto [i, lb] : llvm::enumerate(opBounds)) {
6187 if (lb == blockArg) {
6188 found = true;
6189 if (outBounds)
6190 (*outBounds)[i] = hostEvalVar;
6191 }
6192 }
6193 return found;
6194 };
6195 bool found =
6196 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6197 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6198 found;
6199 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6200 (void)found;
6201 assert(found && "unsupported host_eval use");
6202 })
6203 .DefaultUnreachable("unsupported host_eval use");
6204 }
6205 }
6206}
6207
6208/// If \p op is of the given type parameter, return it casted to that type.
6209/// Otherwise, if its immediate parent operation (or some other higher-level
6210/// parent, if \p immediateParent is false) is of that type, return that parent
6211/// casted to the given type.
6212///
6213/// If \p op is \c null or neither it or its parent(s) are of the specified
6214/// type, return a \c null operation.
6215template <typename OpTy>
6216static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6217 if (!op)
6218 return OpTy();
6219
6220 if (OpTy casted = dyn_cast<OpTy>(op))
6221 return casted;
6222
6223 if (immediateParent)
6224 return dyn_cast_if_present<OpTy>(op->getParentOp());
6225
6226 return op->getParentOfType<OpTy>();
6227}
6228
6229/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6230/// it is of an integer type, return its value.
6231static std::optional<int64_t> extractConstInteger(Value value) {
6232 if (!value)
6233 return std::nullopt;
6234
6235 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6236 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6237 return constAttr.getInt();
6238
6239 return std::nullopt;
6240}
6241
6242static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6243 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6244 uint64_t sizeInBytes = sizeInBits / 8;
6245 return sizeInBytes;
6246}
6247
/// Computes the number of bytes needed to hold one copy of all reduction
/// variables of \p op: the byte size (per the enclosing module's data layout)
/// of a non-packed literal struct whose members are the reduction value types.
/// Returns 0 when \p op carries no reduction clauses.
template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
  if (op.getNumReductionVars() > 0) {
    // Resolve the omp.declare_reduction symbols referenced by the op.
    collectReductionDecls(op, reductions);

    members.reserve(reductions.size());
    for (omp::DeclareReductionOp &red : reductions)
      members.push_back(red.getType());
    Operation *opp = op.getOperation();
    // Aggregate all reduction types into one struct so a single data-layout
    // query accounts for member alignment/padding between them.
    auto structType = mlir::LLVM::LLVMStructType::getLiteral(
        opp->getContext(), members, /*isPacked=*/false);
    DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
    return getTypeByteSize(structType, dl);
  }
  return 0;
}
6266
/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
static void
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice, bool isGPU) {
  // TODO: Handle constant 'if' clauses.

  // Gather the clause operands that influence team/thread counts. On the
  // host they arrive through the target op's host_eval list; on the device
  // they are read directly off the captured teams/parallel ops.
  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      // Handle num_teams upper bounds (only first value for now)
      if (!teamsOp.getNumTeamsUpperVars().empty())
        numTeamsUpper = teamsOp.getNumTeams(0);
      if (!teamsOp.getThreadLimitVars().empty())
        threadLimit = teamsOp.getThreadLimit(0);
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
      if (!parallelOp.getNumThreadsVars().empty())
        numThreads = parallelOp.getNumThreads(0);
    }
  }

  // Handle clauses impacting the number of teams.

  // Convention below: -1 means "unset", 0 means "set but not a compile-time
  // constant", positive means a known constant bound.
  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
    // match clang and set min and max to the same value.
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
                                              /*immediateParent=*/true)) {
    // A parallel/simd region directly inside the target implies exactly one
    // team.
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  if (!targetOp.getThreadLimitVars().empty())
    setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // Reduction buffer sizing is only relevant for GPU team reductions.
  int32_t reductionDataSize = 0;
  if (isGPU && capturedOp) {
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
      reductionDataSize = getReductionDataSize(teamsOp);
  }

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
  assert(
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                               omp::TargetRegionFlags::spmd) &&
      "invalid kernel flags");
  // generic+spmd -> GENERIC_SPMD, generic alone -> GENERIC, spmd alone ->
  // SPMD.
  attrs.ExecFlags =
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
  if (omp::bitEnumContainsAll(kernelFlags,
                              omp::TargetRegionFlags::spmd |
                                  omp::TargetRegionFlags::no_loop) &&
      !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
    attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;

  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
  attrs.ReductionDataSize = reductionDataSize;
  // TODO: Allow modified buffer length similar to
  // fopenmp-cuda-teams-reduction-recs-num flag in clang.
  if (attrs.ReductionDataSize != 0)
    attrs.ReductionBufferLength = 1024;
}
6398
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  // Collect the host-evaluated clause operands, including per-loop bounds and
  // steps needed later for the trip-count computation.
  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // Any missing bound makes the combined trip count unknowable; clear the
      // partial product and give up.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop: seed the product; subsequent loops: multiply in.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Device selection: default to the "undefined device" sentinel; an explicit
  // device clause overrides it, widened/truncated to the runtime's int64 ABI.
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
6486
/// Translates an `omp.target` operation to LLVM IR by outlining its region
/// into a kernel function through `OpenMPIRBuilder::createTarget`, covering
/// both the host (launch + fallback) and device (kernel body) compilation
/// paths.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);

  // The current debug location already has the DISubprogram for the outlined
  // function that will be created for the target op. We save it here so that
  // we can set it on the outlined function.
  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // During the handling of target op, we will generate instructions in the
  // parent function like call to the oulined function or branch to a new
  // BasicBlock. We set the debug location here to parent function so that those
  // get the correct debug locations. For outlined functions, the normal MLIR op
  // conversion will automatically pick the correct location.
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
        outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block
  // argument that corresponds to the MapInfoOp corresponding to the private
  // var in question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;
  TargetDirectiveEnumTy targetDirective =
      getTargetDirectiveEnumTyFromOp(&opInst);

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that required an auxiliary map need the lookup entry.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Generates the body of the outlined kernel function: maps block arguments,
  // performs privatization, then inlines the target region.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Bind map and has_device_addr entry-block arguments to the LLVM values
    // of the mapped host variables.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    PrivateVarsInfo privateVarsInfo(targetOp);

        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
                            allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                    &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect the dealloc regions of all privatizers so cleanup can be
    // inlined at the region's exit block below.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                targetDirective);
    return combinedInfos;
  };

  // Rewrites kernel arguments for device compilation; on the host the
  // argument is used unchanged.
  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likley need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Lazily materializes user-defined mapper functions for map entries that
  // reference one; returns nullptr when entry `i` has no mapper.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
6807
/// Processes the `omp.declare_target` attribute attached to a function or a
/// global during translation: prunes host-only wrapper functions from device
/// modules and registers declare-target globals with the offloading runtime.
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // Host-only functions have no business in a device module: erase the
      // already-translated LLVM function.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Supplies the (filename, line) pair used to build a unique offload
      // entry name; falls back to an empty name / line 0 when the op has no
      // file-line-column location.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      auto vfs = llvm::vfs::getRealFileSystem();

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
          mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      // On the device, non-`to` captures (and unified shared memory builds)
      // additionally need the declare-target reference pointer materialized.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
            mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
            gVal->getType(), /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
6899
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
6924
6925LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6926 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6927 NamedAttribute attribute,
6928 LLVM::ModuleTranslation &moduleTranslation) const {
6929 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6930 attribute.getName())
6931 .Case("omp.is_target_device",
6932 [&](Attribute attr) {
6933 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6934 llvm::OpenMPIRBuilderConfig &config =
6935 moduleTranslation.getOpenMPBuilder()->Config;
6936 config.setIsTargetDevice(deviceAttr.getValue());
6937 return success();
6938 }
6939 return failure();
6940 })
6941 .Case("omp.is_gpu",
6942 [&](Attribute attr) {
6943 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6944 llvm::OpenMPIRBuilderConfig &config =
6945 moduleTranslation.getOpenMPBuilder()->Config;
6946 config.setIsGPU(gpuAttr.getValue());
6947 return success();
6948 }
6949 return failure();
6950 })
6951 .Case("omp.host_ir_filepath",
6952 [&](Attribute attr) {
6953 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6954 llvm::OpenMPIRBuilder *ompBuilder =
6955 moduleTranslation.getOpenMPBuilder();
6956 auto VFS = llvm::vfs::getRealFileSystem();
6957 ompBuilder->loadOffloadInfoMetadata(*VFS,
6958 filepathAttr.getValue());
6959 return success();
6960 }
6961 return failure();
6962 })
6963 .Case("omp.flags",
6964 [&](Attribute attr) {
6965 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6966 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6967 return failure();
6968 })
6969 .Case("omp.version",
6970 [&](Attribute attr) {
6971 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6972 llvm::OpenMPIRBuilder *ompBuilder =
6973 moduleTranslation.getOpenMPBuilder();
6974 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6975 versionAttr.getVersion());
6976 return success();
6977 }
6978 return failure();
6979 })
6980 .Case("omp.declare_target",
6981 [&](Attribute attr) {
6982 if (auto declareTargetAttr =
6983 dyn_cast<omp::DeclareTargetAttr>(attr))
6984 return convertDeclareTargetAttr(op, declareTargetAttr,
6985 moduleTranslation);
6986 return failure();
6987 })
6988 .Case("omp.requires",
6989 [&](Attribute attr) {
6990 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6991 using Requires = omp::ClauseRequires;
6992 Requires flags = requiresAttr.getValue();
6993 llvm::OpenMPIRBuilderConfig &config =
6994 moduleTranslation.getOpenMPBuilder()->Config;
6995 config.setHasRequiresReverseOffload(
6996 bitEnumContainsAll(flags, Requires::reverse_offload));
6997 config.setHasRequiresUnifiedAddress(
6998 bitEnumContainsAll(flags, Requires::unified_address));
6999 config.setHasRequiresUnifiedSharedMemory(
7000 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7001 config.setHasRequiresDynamicAllocators(
7002 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7003 return success();
7004 }
7005 return failure();
7006 })
7007 .Case("omp.target_triples",
7008 [&](Attribute attr) {
7009 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7010 llvm::OpenMPIRBuilderConfig &config =
7011 moduleTranslation.getOpenMPBuilder()->Config;
7012 config.TargetTriples.clear();
7013 config.TargetTriples.reserve(triplesAttr.size());
7014 for (Attribute tripleAttr : triplesAttr) {
7015 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7016 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7017 else
7018 return failure();
7019 }
7020 return success();
7021 }
7022 return failure();
7023 })
7024 .Default([](Attribute) {
7025 // Fall through for omp attributes that do not require lowering.
7026 return success();
7027 })(attribute.getValue());
7028
7029 return failure();
7030}
7031
7032// Returns true if the operation is not inside a TargetOp, it is part of a
7033// function and that function is not declare target.
7034static bool isHostDeviceOp(Operation *op) {
7035 // Assumes no reverse offloading
7036 if (op->getParentOfType<omp::TargetOp>())
7037 return false;
7038
7039 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
7040 if (auto declareTargetIface =
7041 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7042 parentFn.getOperation()))
7043 if (declareTargetIface.isDeclareTarget() &&
7044 declareTargetIface.getDeclareTargetDeviceType() !=
7045 mlir::omp::DeclareTargetDeviceType::host)
7046 return false;
7047
7048 return true;
7049 }
7050
7051 return false;
7052}
7053
7054static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7055 llvm::Module *llvmModule) {
7056 llvm::Type *i64Ty = builder.getInt64Ty();
7057 llvm::Type *i32Ty = builder.getInt32Ty();
7058 llvm::Type *returnType = builder.getPtrTy(0);
7059 llvm::FunctionType *fnType =
7060 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7061 llvm::Function *func = cast<llvm::Function>(
7062 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7063 return func;
7064}
7065
7066static LogicalResult
7067convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7068 LLVM::ModuleTranslation &moduleTranslation) {
7069 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7070 if (!allocMemOp)
7071 return failure();
7072
7073 // Get "omp_target_alloc" function
7074 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7075 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7076 // Get the corresponding device value in llvm
7077 mlir::Value deviceNum = allocMemOp.getDevice();
7078 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7079 // Get the allocation size.
7080 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7081 mlir::Type heapTy = allocMemOp.getAllocatedType();
7082 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
7083 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7084 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7085 for (auto typeParam : allocMemOp.getTypeparams())
7086 allocSize =
7087 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
7088 // Create call to "omp_target_alloc" with the args as translated llvm values.
7089 llvm::CallInst *call =
7090 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7091 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7092
7093 // Map the result
7094 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7095 return success();
7096}
7097
7098static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
7099 llvm::Module *llvmModule) {
7100 llvm::Type *ptrTy = builder.getPtrTy(0);
7101 llvm::Type *i32Ty = builder.getInt32Ty();
7102 llvm::Type *voidTy = builder.getVoidTy();
7103 llvm::FunctionType *fnType =
7104 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
7105 llvm::Function *func = dyn_cast<llvm::Function>(
7106 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
7107 return func;
7108}
7109
7110static LogicalResult
7111convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7112 LLVM::ModuleTranslation &moduleTranslation) {
7113 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7114 if (!freeMemOp)
7115 return failure();
7116
7117 // Get "omp_target_free" function
7118 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7119 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7120 // Get the corresponding device value in llvm
7121 mlir::Value deviceNum = freeMemOp.getDevice();
7122 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7123 // Get the corresponding heapref value in llvm
7124 mlir::Value heapref = freeMemOp.getHeapref();
7125 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7126 // Convert heapref int to ptr and call "omp_target_free"
7127 llvm::Value *intToPtr =
7128 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7129 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7130 return success();
7131}
7132
7133/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
7134/// OpenMP runtime calls).
7135LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7136 Operation *op, llvm::IRBuilderBase &builder,
7137 LLVM::ModuleTranslation &moduleTranslation) const {
7138 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7139
7140 if (ompBuilder->Config.isTargetDevice() &&
7141 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7142 op) &&
7143 isHostDeviceOp(op))
7144 return op->emitOpError() << "unsupported host op found in device";
7145
7146 // For each loop, introduce one stack frame to hold loop information. Ensure
7147 // this is only done for the outermost loop wrapper to prevent introducing
7148 // multiple stack frames for a single loop. Initially set to null, the loop
7149 // information structure is initialized during translation of the nested
7150 // omp.loop_nest operation, making it available to translation of all loop
7151 // wrappers after their body has been successfully translated.
7152 bool isOutermostLoopWrapper =
7153 isa_and_present<omp::LoopWrapperInterface>(op) &&
7154 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
7155
7156 if (isOutermostLoopWrapper)
7157 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7158
7159 auto result =
7160 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7161 .Case([&](omp::BarrierOp op) -> LogicalResult {
7163 return failure();
7164
7165 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7166 ompBuilder->createBarrier(builder.saveIP(),
7167 llvm::omp::OMPD_barrier);
7168 LogicalResult res = handleError(afterIP, *op);
7169 if (res.succeeded()) {
7170 // If the barrier generated a cancellation check, the insertion
7171 // point might now need to be changed to a new continuation block
7172 builder.restoreIP(*afterIP);
7173 }
7174 return res;
7175 })
7176 .Case([&](omp::TaskyieldOp op) {
7178 return failure();
7179
7180 ompBuilder->createTaskyield(builder.saveIP());
7181 return success();
7182 })
7183 .Case([&](omp::FlushOp op) {
7185 return failure();
7186
7187 // No support in Openmp runtime function (__kmpc_flush) to accept
7188 // the argument list.
7189 // OpenMP standard states the following:
7190 // "An implementation may implement a flush with a list by ignoring
7191 // the list, and treating it the same as a flush without a list."
7192 //
7193 // The argument list is discarded so that, flush with a list is
7194 // treated same as a flush without a list.
7195 ompBuilder->createFlush(builder.saveIP());
7196 return success();
7197 })
7198 .Case([&](omp::ParallelOp op) {
7199 return convertOmpParallel(op, builder, moduleTranslation);
7200 })
7201 .Case([&](omp::MaskedOp) {
7202 return convertOmpMasked(*op, builder, moduleTranslation);
7203 })
7204 .Case([&](omp::MasterOp) {
7205 return convertOmpMaster(*op, builder, moduleTranslation);
7206 })
7207 .Case([&](omp::CriticalOp) {
7208 return convertOmpCritical(*op, builder, moduleTranslation);
7209 })
7210 .Case([&](omp::OrderedRegionOp) {
7211 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
7212 })
7213 .Case([&](omp::OrderedOp) {
7214 return convertOmpOrdered(*op, builder, moduleTranslation);
7215 })
7216 .Case([&](omp::WsloopOp) {
7217 return convertOmpWsloop(*op, builder, moduleTranslation);
7218 })
7219 .Case([&](omp::SimdOp) {
7220 return convertOmpSimd(*op, builder, moduleTranslation);
7221 })
7222 .Case([&](omp::AtomicReadOp) {
7223 return convertOmpAtomicRead(*op, builder, moduleTranslation);
7224 })
7225 .Case([&](omp::AtomicWriteOp) {
7226 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
7227 })
7228 .Case([&](omp::AtomicUpdateOp op) {
7229 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
7230 })
7231 .Case([&](omp::AtomicCaptureOp op) {
7232 return convertOmpAtomicCapture(op, builder, moduleTranslation);
7233 })
7234 .Case([&](omp::CancelOp op) {
7235 return convertOmpCancel(op, builder, moduleTranslation);
7236 })
7237 .Case([&](omp::CancellationPointOp op) {
7238 return convertOmpCancellationPoint(op, builder, moduleTranslation);
7239 })
7240 .Case([&](omp::SectionsOp) {
7241 return convertOmpSections(*op, builder, moduleTranslation);
7242 })
7243 .Case([&](omp::SingleOp op) {
7244 return convertOmpSingle(op, builder, moduleTranslation);
7245 })
7246 .Case([&](omp::TeamsOp op) {
7247 return convertOmpTeams(op, builder, moduleTranslation);
7248 })
7249 .Case([&](omp::TaskOp op) {
7250 return convertOmpTaskOp(op, builder, moduleTranslation);
7251 })
7252 .Case([&](omp::TaskloopOp op) {
7253 return convertOmpTaskloopOp(*op, builder, moduleTranslation);
7254 })
7255 .Case([&](omp::TaskgroupOp op) {
7256 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
7257 })
7258 .Case([&](omp::TaskwaitOp op) {
7259 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
7260 })
7261 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7262 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7263 omp::CriticalDeclareOp>([](auto op) {
7264 // `yield` and `terminator` can be just omitted. The block structure
7265 // was created in the region that handles their parent operation.
7266 // `declare_reduction` will be used by reductions and is not
7267 // converted directly, skip it.
7268 // `declare_mapper` and `declare_mapper.info` are handled whenever
7269 // they are referred to through a `map` clause.
7270 // `critical.declare` is only used to declare names of critical
7271 // sections which will be used by `critical` ops and hence can be
7272 // ignored for lowering. The OpenMP IRBuilder will create unique
7273 // name for critical section names.
7274 return success();
7275 })
7276 .Case([&](omp::ThreadprivateOp) {
7277 return convertOmpThreadprivate(*op, builder, moduleTranslation);
7278 })
7279 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7280 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
7281 return convertOmpTargetData(op, builder, moduleTranslation);
7282 })
7283 .Case([&](omp::TargetOp) {
7284 return convertOmpTarget(*op, builder, moduleTranslation);
7285 })
7286 .Case([&](omp::DistributeOp) {
7287 return convertOmpDistribute(*op, builder, moduleTranslation);
7288 })
7289 .Case([&](omp::LoopNestOp) {
7290 return convertOmpLoopNest(*op, builder, moduleTranslation);
7291 })
7292 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7293 [&](auto op) {
7294 // No-op, should be handled by relevant owning operations e.g.
7295 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
7296 // etc. and then discarded
7297 return success();
7298 })
7299 .Case([&](omp::NewCliOp op) {
7300 // Meta-operation: Doesn't do anything by itself, but used to
7301 // identify a loop.
7302 return success();
7303 })
7304 .Case([&](omp::CanonicalLoopOp op) {
7305 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
7306 })
7307 .Case([&](omp::UnrollHeuristicOp op) {
7308 // FIXME: Handling omp.unroll_heuristic as an executable requires
7309 // that the generator (e.g. omp.canonical_loop) has been seen first.
7310 // For construct that require all codegen to occur inside a callback
7311 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
7312 // contained region including their transformations must occur at
7313 // the omp.canonical_loop.
7314 return applyUnrollHeuristic(op, builder, moduleTranslation);
7315 })
7316 .Case([&](omp::TileOp op) {
7317 return applyTile(op, builder, moduleTranslation);
7318 })
7319 .Case([&](omp::FuseOp op) {
7320 return applyFuse(op, builder, moduleTranslation);
7321 })
7322 .Case([&](omp::TargetAllocMemOp) {
7323 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
7324 })
7325 .Case([&](omp::TargetFreeMemOp) {
7326 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
7327 })
7328 .Default([&](Operation *inst) {
7329 return inst->emitError()
7330 << "not yet implemented: " << inst->getName();
7331 });
7332
7333 if (isOutermostLoopWrapper)
7334 moduleTranslation.stackPop();
7335
7336 return result;
7337}
7338
7340 registry.insert<omp::OpenMPDialect>();
7341 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7342 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7343 });
7344}
7345
7347 DialectRegistry registry;
7349 context.appendDialectRegistry(registry);
7350}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1294
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars