// OpenMPToLLVMIRTranslation.cpp — MLIR 23.0.0git (doxygen page header residue)
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
43
44#include <cstdint>
45#include <iterator>
46#include <numeric>
47#include <optional>
48#include <utility>
49
50using namespace mlir;
51
52namespace {
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
70 }
71 llvm_unreachable("unhandled schedule clause argument");
72}
73
74/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
75/// insertion points for allocas.
76class OpenMPAllocStackFrame
77 : public StateStackFrameBase<OpenMPAllocStackFrame> {
78public:
80
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
87};
88
89/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
90/// collapsed canonical loop information corresponding to an \c omp.loop_nest
91/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Set by the loop-nest translation; null until the corresponding
  // omp.loop_nest body has been successfully translated.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
98
99/// Custom error class to signal translation errors that don't need reporting,
100/// since encountering them will have already triggered relevant error messages.
101///
102/// Its purpose is to serve as the glue between MLIR failures represented as
103/// \see LogicalResult instances and \see llvm::Error instances used to
104/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
105/// error of the first type is raised, a message is emitted directly (the \see
106/// LogicalResult itself does not hold any information). If we need to forward
107/// this error condition as an \see llvm::Error while avoiding triggering some
108/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
109/// class to just signal this situation has happened.
110///
111/// For example, this class should be used to trigger errors from within
112/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
113/// translation of their own regions. This unclutters the error log from
114/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  // Intentionally silent: the error was already reported when the original
  // LogicalResult failure was raised.
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  // This error type only propagates a "failure happened" signal; it carries
  // no error code, so conversion is a programming error.
  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

// Storage for the RTTI anchor declared above.
char PreviouslyReportedError::ID = 0;
132
/*
 * Custom class for processing the linear clause for omp.wsloop
 * and omp.simd. Linear clause translation requires setup,
 * initialization, update, and finalization at varying
 * basic blocks in the IR. This class helps maintain
 * internal state to allow consistent translation in
 * each of these stages.
 */

class LinearClauseProcessor {

private:
  // Allocas (".linear_var") holding each linear variable's value on entry to
  // the loop; written by initLinearVar, read by updateLinearVar.
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas (".linear_result") holding each variable's per-iteration value
  // (start + iv * step); written by updateLinearVar.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // Original storage of each linear variable, as seen by the surrounding
  // code; finalization copies the loop-body temps back here.
  SmallVector<llvm::Value *> linearOrigVal;
  // Step value for each linear variable (parallel to linearOrigVal).
  SmallVector<llvm::Value *> linearSteps;
  // LLVM types of the linear variables, populated by registerType.
  SmallVector<llvm::Type *> linearVarTypes;
  // Blocks produced by splitLinearFiniBB: the block holding the last-iteration
  // check, the common exit, and the store-back block taken only on the last
  // logical iteration.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register type for the linear variables. `ty` must be a TypeAttr; its
  // value is converted to the corresponding LLVM type.
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space for linear variables: one precondition slot and one
  // loop-body temp per variable, using the type registered at `idx`.
  // NOTE(review): `moduleTranslation` is unused in this method — confirm
  // whether it can be dropped from the signature.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Record the translated LLVM value of a linear step operand.
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: in the loop preheader,
  // copy each variable's current value into its precondition slot.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR for updating linear variables in the loop body: for each variable
  // compute start + iv * step and store it into the loop-body temp. Integer
  // variables do the arithmetic in the variable's type; floating-point
  // variables multiply in the induction variable's integer type and convert
  // the product with sitofp before the add.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "
                         "type");

      if (linearVarType->isIntegerTy()) {
        // Integer path: normalize all arithmetic to linearVarType
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        // Float path: perform multiply in integer, then convert to float
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else {
        llvm_unreachable(
            "Linear variable must be of integer or floating-point type");
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same: loopExit is split into the
  // finalization block (condition), the exit block, and the last-iteration
  // store-back block.
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: on the last logical iteration (lastIter != 0),
  // copy the loop-body temps back into the original variables, then emit a
  // barrier in the exit block. Returns the barrier's insertion point (or an
  // error from the OpenMPIRBuilder).
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Emit stores for linear variables. Useful in case of SIMD
  // construct, where the values are written back unconditionally.
  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  // Rewrite all uses of the original variable in `BBName`
  // with the linear variable in-place. Uses are collected first because
  // replaceUsesOfWith mutates the use list being iterated.
  // NOTE(review): `builder` is unused in this method — confirm whether it can
  // be dropped from the signature. Matching blocks by substring of the block
  // name is fragile; verify BBName cannot collide with unrelated blocks.
  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
                      size_t varIndex) {
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (userInst->getParent()->getName().str().find(BBName) !=
            std::string::npos)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
304
305} // namespace
306
307/// Looks up from the operation from and returns the PrivateClauseOp with
308/// name symbolName
309static omp::PrivateClauseOp findPrivatizer(Operation *from,
310 SymbolRefAttr symbolName) {
311 omp::PrivateClauseOp privatizer =
313 symbolName);
314 assert(privatizer && "privatizer not found in the symbol table");
315 return privatizer;
316}
317
318/// Check whether translation to LLVM IR for the given operation is currently
319/// supported. If not, descriptive diagnostics will be emitted to let users know
320/// this is a not-yet-implemented feature.
321///
322/// \returns success if no unimplemented features are needed to translate the
323/// given operation.
324static LogicalResult checkImplementationStatus(Operation &op) {
325 auto todo = [&op](StringRef clauseName) {
326 return op.emitError() << "not yet implemented: Unhandled clause "
327 << clauseName << " in " << op.getName()
328 << " operation";
329 };
330
331 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
332 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
333 result = todo("allocate");
334 };
335 auto checkBare = [&todo](auto op, LogicalResult &result) {
336 if (op.getBare())
337 result = todo("ompx_bare");
338 };
339 auto checkDepend = [&todo](auto op, LogicalResult &result) {
340 if (!op.getDependVars().empty() || op.getDependKinds())
341 result = todo("depend");
342 };
343 auto checkHint = [](auto op, LogicalResult &) {
344 if (op.getHint())
345 op.emitWarning("hint clause discarded");
346 };
347 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
348 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
349 op.getInReductionSyms())
350 result = todo("in_reduction");
351 };
352 auto checkNowait = [&todo](auto op, LogicalResult &result) {
353 if (op.getNowait())
354 result = todo("nowait");
355 };
356 auto checkOrder = [&todo](auto op, LogicalResult &result) {
357 if (op.getOrder() || op.getOrderMod())
358 result = todo("order");
359 };
360 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
361 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
362 result = todo("privatization");
363 };
364 auto checkReduction = [&todo](auto op, LogicalResult &result) {
365 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
366 if (!op.getReductionVars().empty() || op.getReductionByref() ||
367 op.getReductionSyms())
368 result = todo("reduction");
369 if (op.getReductionMod() &&
370 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
371 result = todo("reduction with modifier");
372 };
373 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
374 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
375 op.getTaskReductionSyms())
376 result = todo("task_reduction");
377 };
378 auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
379 if (op.hasNumTeamsMultiDim())
380 result = todo("num_teams with multi-dimensional values");
381 };
382 auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
383 if (op.hasNumThreadsMultiDim())
384 result = todo("num_threads with multi-dimensional values");
385 };
386
387 auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
388 if (op.hasThreadLimitMultiDim())
389 result = todo("thread_limit with multi-dimensional values");
390 };
391
392 auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
393 if (op.getDynGroupprivateSize())
394 result = todo("dyn_groupprivate");
395 };
396
397 LogicalResult result = success();
399 .Case([&](omp::DistributeOp op) {
400 checkAllocate(op, result);
401 checkOrder(op, result);
402 })
403 .Case([&](omp::SectionsOp op) {
404 checkAllocate(op, result);
405 checkPrivate(op, result);
406 checkReduction(op, result);
407 })
408 .Case([&](omp::ScopeOp op) {
409 checkAllocate(op, result);
410 checkReduction(op, result);
411 })
412 .Case([&](omp::SingleOp op) {
413 checkAllocate(op, result);
414 checkPrivate(op, result);
415 })
416 .Case([&](omp::TeamsOp op) {
417 checkAllocate(op, result);
418 checkPrivate(op, result);
419 checkNumTeams(op, result);
420 checkThreadLimit(op, result);
421 checkDynGroupprivate(op, result);
422 })
423 .Case([&](omp::TaskOp op) {
424 checkAllocate(op, result);
425 checkInReduction(op, result);
426 })
427 .Case([&](omp::TaskgroupOp op) {
428 checkAllocate(op, result);
429 checkTaskReduction(op, result);
430 })
431 .Case([&](omp::TaskwaitOp op) {
432 checkDepend(op, result);
433 checkNowait(op, result);
434 })
435 .Case([&](omp::TaskloopContextOp op) {
436 checkAllocate(op, result);
437 checkInReduction(op, result);
438 checkReduction(op, result);
439 })
440 .Case([&](omp::WsloopOp op) {
441 checkAllocate(op, result);
442 checkOrder(op, result);
443 checkReduction(op, result);
444 })
445 .Case([&](omp::ParallelOp op) {
446 checkAllocate(op, result);
447 checkReduction(op, result);
448 checkNumThreads(op, result);
449 })
450 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
451 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
452 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
453 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
454 [&](auto op) { checkDepend(op, result); })
455 .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
456 .Case([&](omp::TargetOp op) {
457 checkAllocate(op, result);
458 checkBare(op, result);
459 checkInReduction(op, result);
460 checkThreadLimit(op, result);
461 })
462 .Default([](Operation &) {
463 // Assume all clauses for an operation can be translated unless they are
464 // checked above.
465 });
466 return result;
467}
468
469static LogicalResult handleError(llvm::Error error, Operation &op) {
470 LogicalResult result = success();
471 if (error) {
472 llvm::handleAllErrors(
473 std::move(error),
474 [&](const PreviouslyReportedError &) { result = failure(); },
475 [&](const llvm::ErrorInfoBase &err) {
476 result = op.emitError(err.message());
477 });
478 }
479 return result;
480}
481
482template <typename T>
483static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
484 if (!result)
485 return handleError(result.takeError(), op);
486
487 return success();
488}
489
490/// Find the insertion point for allocas given the current insertion point for
491/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(
    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
    llvm::SmallVectorImpl<llvm::BasicBlock *> *deallocBlocks = nullptr) {
  // If there is an allocation insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it. The walk stops at the innermost (most recently pushed) frame.
  llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
  llvm::ArrayRef<llvm::BasicBlock *> deallocInsertPoints;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocStackFrame>(
      [&](OpenMPAllocStackFrame &frame) {
        allocInsertPoint = frame.allocInsertPoint;
        deallocInsertPoints = frame.deallocBlocks;
        return WalkResult::interrupt();
      });
  // In cases with multiple levels of outlining, the tree walk might find an
  // insertion point that is inside the original function while the builder
  // insertion point is inside the outlined function. We need to make sure that
  // we do not use it in those cases.
  if (walkResult.wasInterrupted() &&
      allocInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent()) {
    // Forward the frame's dealloc blocks to the caller, if requested.
    if (deallocBlocks)
      deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
                            deallocInsertPoints.end());
    return allocInsertPoint;
  }

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocs.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  // Collect exit blocks, which is where explicit deallocations should happen in
  // this case. A "return"-terminated block counts as an exit block.
  if (deallocBlocks) {
    for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
      // TODO: This currently results in no blocks being added to the list when
      // all exit blocks of the enclosing function have not been lowered before
      // this is reached.
      llvm::Instruction *terminator = block.getTerminatorOrNull();
      if (isa_and_present<llvm::ReturnInst>(terminator))
        deallocBlocks->emplace_back(&block);
    }
  }

  // Allocas go at the very start of the enclosing function's entry block.
  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
555
556/// Find the loop information structure for the loop nest being translated. It
557/// will return a `null` value unless called from the translation function for
558/// a loop wrapper operation after successfully translating its body.
559static llvm::CanonicalLoopInfo *
561 llvm::CanonicalLoopInfo *loopInfo = nullptr;
562 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
563 [&](OpenMPLoopInfoStackFrame &frame) {
564 loopInfo = frame.loopInfo;
565 return WalkResult::interrupt();
566 });
567 return loopInfo;
568}
569
570/// Converts the given region that appears within an OpenMP dialect operation to
571/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
572/// region, and a branch from any block with an successor-less OpenMP terminator
573/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
574/// of the continuation block if provided.
576 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
577 LLVM::ModuleTranslation &moduleTranslation,
578 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
579 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
580
581 llvm::BasicBlock *continuationBlock =
582 splitBB(builder, true, "omp.region.cont");
583 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
584
585 llvm::LLVMContext &llvmContext = builder.getContext();
586 for (Block &bb : region) {
587 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
588 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
589 builder.GetInsertBlock()->getNextNode());
590 moduleTranslation.mapBlock(&bb, llvmBB);
591 }
592
593 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
594
595 // Terminators (namely YieldOp) may be forwarding values to the region that
596 // need to be available in the continuation block. Collect the types of these
597 // operands in preparation of creating PHI nodes. This is skipped for loop
598 // wrapper operations, for which we know in advance they have no terminators.
599 SmallVector<llvm::Type *> continuationBlockPHITypes;
600 unsigned numYields = 0;
601
602 if (!isLoopWrapper) {
603 bool operandsProcessed = false;
604 for (Block &bb : region.getBlocks()) {
605 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
606 if (!operandsProcessed) {
607 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
608 continuationBlockPHITypes.push_back(
609 moduleTranslation.convertType(yield->getOperand(i).getType()));
610 }
611 operandsProcessed = true;
612 } else {
613 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
614 "mismatching number of values yielded from the region");
615 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
616 llvm::Type *operandType =
617 moduleTranslation.convertType(yield->getOperand(i).getType());
618 (void)operandType;
619 assert(continuationBlockPHITypes[i] == operandType &&
620 "values of mismatching types yielded from the region");
621 }
622 }
623 numYields++;
624 }
625 }
626 }
627
628 // Insert PHI nodes in the continuation block for any values forwarded by the
629 // terminators in this region.
630 if (!continuationBlockPHITypes.empty())
631 assert(
632 continuationBlockPHIs &&
633 "expected continuation block PHIs if converted regions yield values");
634 if (continuationBlockPHIs) {
635 llvm::IRBuilderBase::InsertPointGuard guard(builder);
636 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
637 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
638 for (llvm::Type *ty : continuationBlockPHITypes)
639 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
640 }
641
642 // Convert blocks one by one in topological order to ensure
643 // defs are converted before uses.
645 for (Block *bb : blocks) {
646 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
647 // Retarget the branch of the entry block to the entry block of the
648 // converted region (regions are single-entry).
649 if (bb->isEntryBlock()) {
650 assert(sourceTerminator->getNumSuccessors() == 1 &&
651 "provided entry block has multiple successors");
652 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
653 "ContinuationBlock is not the successor of the entry block");
654 sourceTerminator->setSuccessor(0, llvmBB);
655 }
656
657 llvm::IRBuilderBase::InsertPointGuard guard(builder);
658 if (failed(
659 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
660 return llvm::make_error<PreviouslyReportedError>();
661
662 // Create a direct branch here for loop wrappers to prevent their lack of a
663 // terminator from causing a crash below.
664 if (isLoopWrapper) {
665 builder.CreateBr(continuationBlock);
666 continue;
667 }
668
669 // Special handling for `omp.yield` and `omp.terminator` (we may have more
670 // than one): they return the control to the parent OpenMP dialect operation
671 // so replace them with the branch to the continuation block. We handle this
672 // here to avoid relying inter-function communication through the
673 // ModuleTranslation class to set up the correct insertion point. This is
674 // also consistent with MLIR's idiom of handling special region terminators
675 // in the same code that handles the region-owning operation.
676 Operation *terminator = bb->getTerminator();
677 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
678 builder.CreateBr(continuationBlock);
679
680 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
681 (*continuationBlockPHIs)[i]->addIncoming(
682 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
683 }
684 }
685 // After all blocks have been traversed and values mapped, connect the PHI
686 // nodes to the results of preceding blocks.
687 LLVM::detail::connectPHINodes(region, moduleTranslation);
688
689 // Remove the blocks and values defined in this region from the mapping since
690 // they are not visible outside of this region. This allows the same region to
691 // be converted several times, that is cloned, without clashes, and slightly
692 // speeds up the lookups.
693 moduleTranslation.forgetMapping(region);
694
695 return continuationBlock;
696}
697
698/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
699static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
700 switch (kind) {
701 case omp::ClauseProcBindKind::Close:
702 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
703 case omp::ClauseProcBindKind::Master:
704 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
705 case omp::ClauseProcBindKind::Primary:
706 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
707 case omp::ClauseProcBindKind::Spread:
708 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
709 }
710 llvm_unreachable("Unknown ClauseProcBindKind kind");
711}
712
713/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Body callback: translate the masked region at the given code-gen point.
  // NOTE(review): the lambda's parameter list appears truncated in this copy
  // (the visible line ends with a comma) — confirm against upstream.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
    // MaskedOp has only one region associated with it.
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.masked.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // The filter value selects which thread executes the region; without a
  // `filter` clause the default is thread 0.
  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
  } else {
    llvm::LLVMContext &llvmContext = builder.getContext();
    filterVal =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
  }
  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
                                                         finiCB, filterVal);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Continue emitting code after the masked construct.
  builder.restoreIP(*afterIP);
  return success();
}
757
758/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Body callback: translate the master region at the given code-gen point.
  // NOTE(review): the lambda's parameter list appears truncated in this copy
  // (the visible line ends with a comma) — confirm against upstream.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
    // MasterOp has only one region associated with it.
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.master.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
                                                         finiCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Continue emitting code after the master construct.
  builder.restoreIP(*afterIP);
  return success();
}
793
794/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Body callback: translate the critical region at the given code-gen point.
  // NOTE(review): the lambda's parameter list appears truncated in this copy
  // (the visible line ends with a comma) — confirm against upstream.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too.
  if (criticalOp.getNameAttr()) {
    // The verifiers in OpenMP Dialect guarentee that all the pointers are
    // non-null
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    // NOTE(review): the symbol lookup statement below appears truncated in
    // this copy (presumably a
    // SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp> call) —
    // confirm against upstream.
    auto criticalDeclareOp =
        symbolRef);
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Continue emitting code after the critical construct.
  builder.restoreIP(*afterIP);
  return success();
}
844
/// A util to collect info needed to convert delayed privatizers from MLIR to
/// LLVM.
  template <typename OP>
      : blockArgs(
            cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
    // Reserve one slot per private block argument; the vectors are filled in
    // parallel as privatization proceeds.
    mlirVars.reserve(blockArgs.size());
    llvmVars.reserve(blockArgs.size());
    collectPrivatizationDecls<OP>(op);

    // Record the MLIR values being privatized, in clause order.
    for (mlir::Value privateVar : op.getPrivateVars())
      mlirVars.push_back(privateVar);
  }


private:
  /// Populates `privatizations` with privatization declarations used for the
  /// given op.
  template <class OP>
  void collectPrivatizationDecls(OP op) {
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    // No private symbols attribute means there is nothing to privatize.
    if (!attr)
      return;

    privatizers.reserve(privatizers.size() + attr->size());
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
      // Resolve each symbol to its privatizer declaration op.
      privatizers.push_back(findPrivatizer(op, symbolRef));
    }
  }
};
880
/// Populates `reductions` with reduction declarations used in the given op.
template <typename T>
static void
  std::optional<ArrayAttr> attr = op.getReductionSyms();
  // Ops without reduction symbols contribute no declarations.
  if (!attr)
    return;

  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    // Resolve each reduction symbol to its declaration op.
    reductions.push_back(
        op, symbolRef));
  }
}
897
/// Translates the blocks contained in the given region and appends them at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  // An empty region translates to no IR at all.
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (region.hasOneBlock()) {
    // Remember the current block's terminator (if any) so it can be
    // temporarily detached while the region's ops are appended.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    // Re-attach the previously detached terminator at the end of the block.
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    }

    return success();
  }

  // Multi-block case: delegate to the full region conversion, which creates
  // the necessary blocks and returns the continuation block.
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);

  if (failed(handleError(continuationBlock, *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  // Resume emission at the start of the continuation block.
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
  return success();
}
965
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: takes the LHS and RHS values and yields the combined
/// value through the `llvm::Value *&` out-parameter.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *)>;
/// Atomic combiner: takes the element type and the LHS/RHS values.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// Generator for the `data_ptr_ptr` region of by-ref reductions; yields the
/// inner data pointer through the out-parameter.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
981
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
/// reduction declaration. The generator uses `builder` but ignores its
/// insertion point.
static OwningReductionGen
makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the combiner region's block arguments to the runtime-provided
    // LHS/RHS values, then inline the region at the given insertion point.
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    // The single yielded value is the combined reduction result.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };
  return gen;
}
1010
/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
/// given reduction declaration. The generator uses `builder` but ignores its
/// insertion point. Returns null if there is no atomic region available in the
/// reduction declaration.
static OwningAtomicReductionGen
makeAtomicReductionGen(omp::DeclareReductionOp decl,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  // A null generator signals to callers that no atomic lowering is possible.
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningAtomicReductionGen atomicGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the atomic region's block arguments to the runtime-provided
    // LHS/RHS values, then inline the region at the given insertion point.
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    // The atomic region performs its update in place and yields no values.
    assert(phis.empty());
    return builder.saveIP();
  };
  return atomicGen;
}
1043
/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
/// the given reduction declaration. The generator uses `builder` but ignores
/// its insertion point. Returns null if there is no `data_ptr_ptr` region
/// available in the reduction declaration.
static OwningDataPtrPtrReductionGen
makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
  // Only by-ref reductions with an explicit data_ptr_ptr region get a
  // generator; otherwise return a null function object.
  if (!isByRef || decl.getDataPtrPtrRegion().empty())
    return OwningDataPtrPtrReductionGen();

  // Capture decl by value so the generator remains valid after this function
  // returns; mutable for access to its non-const accessors.
  OwningDataPtrPtrReductionGen refDataPtrGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *byRefVal, llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the region's argument to the by-ref variable and inline the region.
    moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
                                       "omp.data_ptr_ptr.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
    // The single yielded value is the extracted data pointer.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };

  return refDataPtrGen;
}
1072
1073/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1074static LogicalResult
1075convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1076 LLVM::ModuleTranslation &moduleTranslation) {
1077 auto orderedOp = cast<omp::OrderedOp>(opInst);
1078
1079 if (failed(checkImplementationStatus(opInst)))
1080 return failure();
1081
1082 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1083 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1084 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1085 SmallVector<llvm::Value *> vecValues =
1086 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1087
1088 size_t indexVecValues = 0;
1089 while (indexVecValues < vecValues.size()) {
1090 SmallVector<llvm::Value *> storeValues;
1091 storeValues.reserve(numLoops);
1092 for (unsigned i = 0; i < numLoops; i++) {
1093 storeValues.push_back(vecValues[indexVecValues]);
1094 indexVecValues++;
1095 }
1096 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1097 findAllocInsertPoints(builder, moduleTranslation);
1098 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1099 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1100 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1101 }
1102 return success();
1103}
1104
1105/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1106/// OpenMPIRBuilder.
1107static LogicalResult
1108convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1109 LLVM::ModuleTranslation &moduleTranslation) {
1110 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1111 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1112
1113 if (failed(checkImplementationStatus(opInst)))
1114 return failure();
1115
1116 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1117 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
1118 // OrderedOp has only one region associated with it.
1119 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1120 builder.restoreIP(codeGenIP);
1121 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1122 moduleTranslation)
1123 .takeError();
1124 };
1125
1126 // TODO: Perform finalization actions for variables. This has to be
1127 // called for variables which have destructors/finalizers.
1128 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1129
1130 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1131 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1132 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1133 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1134
1135 if (failed(handleError(afterIP, opInst)))
1136 return failure();
1137
1138 builder.restoreIP(*afterIP);
1139 return success();
1140}
1141
1142namespace {
1143/// Contains the arguments for an LLVM store operation
1144struct DeferredStore {
1145 DeferredStore(llvm::Value *value, llvm::Value *address)
1146 : value(value), address(address) {}
1147
1148 llvm::Value *value;
1149 llvm::Value *address;
1150};
1151} // namespace
1152
/// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas
template <typename T>
static LogicalResult
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // All allocations go into the dedicated alloca block, before its terminator;
  // the guard restores the caller's insertion point on exit.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  // In a shared device context, storage comes from OpenMP shared memory
  // instead of plain allocas.
  bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);

  // delay creating stores until after all allocas
  deferredStores.reserve(op.getNumReductionVars());

  for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      if (allocRegion.empty())
        continue;

      // Inline the `alloc` region; it yields the real reduction storage.
      if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                         builder, moduleTranslation, &phis)))
        return op.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Type *varTy =
          moduleTranslation.convertType(reductionDecls[i].getType());
      llvm::Value *var;
      if (useDeviceSharedMem) {
        var = ompBuilder->createOMPAllocShared(builder, varTy);
      } else {
        var = builder.CreateAlloca(varTy);
        var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      }

      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);

      // The store of the alloc-region result must happen after every alloca,
      // so record it for later materialization.
      deferredStores.emplace_back(castPhi, var);

      privateReductionVariables[i] = var;
      moduleTranslation.mapValue(reductionArgs[i], castPhi);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");

      // By-value reductions get a slot of the reduction's converted type
      // directly.
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Type *varTy =
          moduleTranslation.convertType(reductionDecls[i].getType());
      llvm::Value *var;
      if (useDeviceSharedMem) {
        var = ompBuilder->createOMPAllocShared(builder, varTy);
      } else {
        var = builder.CreateAlloca(varTy);
        var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      }

      moduleTranslation.mapValue(reductionArgs[i], var);
      privateReductionVariables[i] = var;
      reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
    }
  }

  return success();
}
1235
/// Map input arguments to reduction initialization region
template <typename T>
static void
                      llvm::IRBuilderBase &builder,
                      DenseMap<Value, llvm::Value *> &reductionVariableMap,
                      unsigned i) {
  // map input argument to the initialization region
  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
  Region &initializerRegion = reduction.getInitializerRegion();
  Block &entry = initializerRegion.front();

  mlir::Value mlirSource = loop.getReductionVars()[i];
  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
  llvm::Value *origVal = llvmSource;
  // If a non-pointer value is expected, load the value from the source pointer.
  if (!isa<LLVM::LLVMPointerType>(
          reduction.getInitializerMoldArg().getType()) &&
      isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
    origVal =
        builder.CreateLoad(moduleTranslation.convertType(
                               reduction.getInitializerMoldArg().getType()),
                           llvmSource, "omp_orig");
  }
  // The mold argument receives the original (pre-reduction) value.
  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);

  // When the entry block has a second argument, it receives the allocation
  // previously recorded for this reduction variable.
  if (entry.getNumArguments() > 1) {
    llvm::Value *allocation =
        reductionVariableMap.lookup(loop.getReductionVars()[i]);
    moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
  }
}
1269
1270static void
1271setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1272 llvm::BasicBlock *block = nullptr) {
1273 if (block == nullptr)
1274 block = builder.GetInsertBlock();
1275
1276 if (!block->hasTerminator())
1277 builder.SetInsertPoint(block);
1278 else
1279 builder.SetInsertPoint(block->getTerminator());
1280}
1281
/// Inline reductions' `init` regions. This functions assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder's insertions point in a state
/// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  if (op.getNumReductionVars() == 0)
    return success();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);

  // Split off a dedicated block for init code; allocas still go into the
  // latest alloca block.
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      // Reductions with an alloc region already got storage there.
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduction variable allocated in the inlined region)
      llvm::Type *varTy =
          moduleTranslation.convertType(reductionDecls[i].getType());
      if (useDeviceSharedMem)
        byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
      else
        byRefVars[i] = builder.CreateAlloca(varTy);
    }
  }

  setInsertPointForPossiblyEmptyBlock(builder, initBlock);

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
                          reductionVariableMap, i);

    // TODO In some cases (specially on the GPU), the init regions may
    // contain stack allocations. If the region is inlined in a loop, this is
    // problematic. Instead of just inlining the region, handle allocations by
    // hoisting fixed length allocations to the function entry and using
    // stacksave and restore for variable length ones.
    if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(phis[0], privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1388
/// Collect reduction info
///
/// Builds the owning generator vectors (non-atomic, atomic, data_ptr_ptr) and
/// fills `reductionInfos` with one OpenMPIRBuilder::ReductionInfo entry per
/// reduction variable of `loop`.
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    ArrayRef<bool> isByRef) {
  unsigned numReductions = loop.getNumReductionVars();

  // Create one generator of each kind per reduction declaration; the atomic
  // and data_ptr_ptr generators may be null when the declaration lacks the
  // corresponding region.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    // Determine the element type actually allocated by the alloc region, if
    // any, by looking for an alloca inside it.
    mlir::Type allocatedType;
    reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
      if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
        allocatedType = alloca.getElemType();
      }

    });

    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         /*ReductionGenClang=*/nullptr, atomicGen,
         allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
         reductionDecls[i].getByrefElementType()
             ? moduleTranslation.convertType(
                   *reductionDecls[i].getByrefElementType())
             : nullptr});
  }
}
1444
/// handling of DeclareReductionOp's cleanup region
///
/// Inlines each non-empty cleanup region at the current insertion point,
/// binding the region's single block argument to the corresponding private
/// variable (loaded first when `shouldLoadCleanupRegionArg` is set).
static LogicalResult
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
    // Declarations without a cleanup region need no cleanup code.
    if (cleanupRegion->empty())
      continue;

    // map the argument to the cleanup region
    Block &entry = cleanupRegion->front();

    // If the current block already ends in a terminator, insert the cleanup
    // code before it rather than after.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    // Pass either the loaded value or the raw pointer, matching what the
    // region's argument type expects.
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  moduleTranslation.convertType(entry.getArgument(0).getType()),
                  privateVariables[i])
            : privateVariables[i];

    moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);

    if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
                                       moduleTranslation)))
      return failure();

    // clear block argument mapping in case it needs to be re-created with a
    // different source for another use of the same reduction decl
    moduleTranslation.forgetMapping(*cleanupRegion);
  }
  return success();
}
1483
// TODO: not used by ParallelOp
/// Emits the reduction combination for `op` via OpenMPIRBuilder, adds a
/// barrier for non-teams reductions, then inlines each declaration's cleanup
/// region and releases device-shared storage when applicable.
template <class OP>
static LogicalResult createReductionsAndCleanup(
    OP op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
    bool isNowait = false, bool isTeamsReduction = false) {
  // Process the reductions if required.
  if (op.getNumReductionVars() == 0)
    return success();

  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
  SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Create the reduction generators. We need to own them here because
  // ReductionInfo only accepts references to the generators.
  collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                       owningReductionGens, owningAtomicReductionGens,
                       owningReductionGenRefDataPtrGens,
                       privateReductionVariables, reductionInfos, isByRef);

  // The call to createReductions below expects the block to have a
  // terminator. Create an unreachable instruction to serve as terminator
  // and remove it later.
  llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
  builder.SetInsertPoint(tempTerminator);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
                                   isByRef, isNowait, isTeamsReduction);

  if (failed(handleError(contInsertPoint, *op)))
    return failure();

  if (!contInsertPoint->getBlock())
    return op->emitOpError() << "failed to convert reductions";

  llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
  // Non-teams reductions are followed by a barrier.
  if (!isTeamsReduction) {
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
        ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);

    if (failed(handleError(barrierIP, *op)))
      return failure();
    afterIP = *barrierIP;
  }

  tempTerminator->eraseFromParent();
  builder.restoreIP(afterIP);

  // after the construct, deallocate private reduction variables
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  LogicalResult result = inlineOmpRegionCleanup(
      reductionRegions, privateReductionVariables, moduleTranslation, builder,
      "omp.reduction.cleanup");

  // Device-shared reduction storage must be explicitly freed.
  bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
  if (useDeviceSharedMem) {
    for (auto [var, reductionDecl] :
         llvm::zip_equal(privateReductionVariables, reductionDecls))
      ompBuilder->createOMPFreeShared(
          builder, var, moduleTranslation.convertType(reductionDecl.getType()));
  }

  return result;
}
1559
1560static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1561 if (!attr)
1562 return {};
1563 return *attr;
1564}
1565
// TODO: not used by omp.parallel
/// Convenience wrapper: first allocates all reduction variables, then runs
/// their `init` regions (materializing any stores deferred by allocation).
template <typename OP>
    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
    DenseMap<Value, llvm::Value *> &reductionVariableMap,
    llvm::ArrayRef<bool> isByRef) {
  // Nothing to do when the op carries no reduction clauses.
  if (op.getNumReductionVars() == 0)
    return success();

  SmallVector<DeferredStore> deferredStores;

  // Allocate first; stores into the allocations are deferred so they land
  // after every alloca.
  if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
                                allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  // Then emit the deferred stores and inline the `init` regions.
  return initReductionVars(op, reductionArgs, builder, moduleTranslation,
                           allocaIP.getBlock(), reductionDecls,
                           privateReductionVariables, reductionVariableMap,
                           isByRef, deferredStores);
}
1592
1593/// Return the llvm::Value * corresponding to the `privateVar` that
1594/// is being privatized. It isn't always as simple as looking up
1595/// moduleTranslation with privateVar. For instance, in case of
1596/// an allocatable, the descriptor for the allocatable is privatized.
1597/// This descriptor is mapped using an MapInfoOp. So, this function
1598/// will return a pointer to the llvm::Value corresponding to the
1599/// block argument for the mapped descriptor.
1600static llvm::Value *
1601findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1602 LLVM::ModuleTranslation &moduleTranslation,
1603 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1604 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1605 return moduleTranslation.lookupValue(privateVar);
1606
1607 Value blockArg = (*mappedPrivateVars)[privateVar];
1608 Type privVarType = privateVar.getType();
1609 Type blockArgType = blockArg.getType();
1610 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1611 "A block argument corresponding to a mapped var should have "
1612 "!llvm.ptr type");
1613
1614 if (privVarType == blockArgType)
1615 return moduleTranslation.lookupValue(blockArg);
1616
1617 // This typically happens when the privatized type is lowered from
1618 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1619 // struct/pair is passed by value. But, mapped values are passed only as
1620 // pointers, so before we privatize, we must load the pointer.
1621 if (!isa<LLVM::LLVMPointerType>(privVarType))
1622 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1623 moduleTranslation.lookupValue(blockArg));
1624
1625 return moduleTranslation.lookupValue(privateVar);
1626}
1627
/// Initialize a single (first)private variable. You probably want to use
/// allocateAndInitPrivateVars instead of this.
/// This returns the private variable which has been initialized. This
/// variable should be mapped before constructing the body of the Op.
initPrivateVar(llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation,
               omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
               BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
               llvm::BasicBlock *privInitBlock,
               llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  Region &initRegion = privDecl.getInitRegion();
  // Without an `init` region the allocated private variable is used as-is.
  if (initRegion.empty())
    return llvmPrivateVar;

  assert(nonPrivateVar);
  // Bind the init region's mold argument to the original variable and its
  // private argument to the freshly allocated one.
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);

  // in-place convert the private initialization region
  if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
                                     moduleTranslation, &phis)))
    return llvm::createStringError(
        "failed to inline `init` region of `omp.private`");

  assert(phis.size() == 1 && "expected one allocation to be yielded");

  // clear init region block argument mapping in case it needs to be
  // re-created with a different source for another use of the same
  // reduction decl
  moduleTranslation.forgetMapping(initRegion);

  // Prefer the value yielded from the init region to the allocated private
  // variable in case the region is operating on arguments by-value (e.g.
  // Fortran character boxes).
  return phis[0];
}
1666
/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
/// The mold value is resolved via findAssociatedValue so that variables
/// remapped through \p mappedPrivateVars are handled transparently.
    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
    omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
    llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  return initPrivateVar(
      builder, moduleTranslation, privDecl,
      findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
                          mappedPrivateVars),
      blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
}
1679
/// Run the `init` region of every delayed privatizer in \p privateVarsInfo
/// inside a dedicated "omp.private.init" block split off at the current
/// insertion point, mapping each corresponding entry-block argument to the
/// initialized private value.
static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation,
                PrivateVarsInfo &privateVarsInfo,
                llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Nothing to initialize when the op carries no private block arguments.
  if (privateVarsInfo.blockArgs.empty())
    return llvm::Error::success();

  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
  setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);

  for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
        llvmPrivateVar, privInitBlock, mappedPrivateVars);

    if (!privVarOrErr)
      return privVarOrErr.takeError();

    // Use the (possibly replacement) value yielded by the init region and
    // map the construct's block argument to it.
    llvmPrivateVar = privVarOrErr.get();
    moduleTranslation.mapValue(blockArg, llvmPrivateVar);
  }

  return llvm::Error::success();
}
1710
/// Allocate and initialize delayed private variables. Returns the basic block
/// which comes after all of these allocations. llvm::Value * for each of these
/// private variables are populated in llvmPrivateVars.
template <typename T>
allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    PrivateVarsInfo &privateVarsInfo,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  // Split the alloca block so everything after the allocas starts in a fresh
  // "omp.region.after_alloca" block; the allocas stay in the original block.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator since the alloca block was split above.
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  // The new terminator is an unconditional branch created by the splitBB above.
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS = moduleTranslation.getLLVMModule()
                               ->getDataLayout()
                               .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs)) {
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = nullptr;
    // Use runtime-managed device shared memory when the alloca's uses
    // require it; otherwise a plain alloca suffices.
    if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
      llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
    } else {
      llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
      // Allocas may live in a non-default address space; cast so later uses
      // see a pointer in the program's default address space.
      if (allocaAS != defaultAS)
        llvmPrivateVar = builder.CreateAddrSpaceCast(
            llvmPrivateVar, builder.getPtrTy(defaultAS));
    }

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  }

  return afterAllocas;
}
1769
/// This can't always be determined statically, but when we can, it is good to
/// avoid generating compiler-added barriers which will deadlock the program.
  for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    // An enclosing omp.single or omp.critical means this op runs on a single
    // thread.
    if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
      return true;

    // e.g.
    // omp.single {
    //   omp.parallel {
    //     op
    //   }
    // }
    if (mlir::isa<omp::ParallelOp>(parent))
      return false;
  }
  // No enclosing single-threaded construct was found: conservatively assume
  // the op may execute on multiple threads.
  return false;
}
1789
/// Run the `copy` region of each firstprivate privatizer in \p privateDecls,
/// copying the mold (original) value into the freshly initialized private
/// copy. When \p insertBarrier is set, a barrier is emitted afterwards unless
/// the op is known to execute single-threaded.
static LogicalResult copyFirstPrivateVars(
    mlir::Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    ArrayRef<llvm::Value *> llvmPrivateVars,
    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);

  for (auto [decl, moldVar, llvmVar] :
       llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
    // Only firstprivate decls have a copy to perform.
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");


    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  // Skip the barrier when we can prove the op runs single-threaded, since a
  // barrier there would deadlock (see opIsInSingleThread).
  if (insertBarrier && !opIsInSingleThread(op)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(res, *op)))
      return failure();
  }

  return success();
}
1849
1850static LogicalResult copyFirstPrivateVars(
1851 mlir::Operation *op, llvm::IRBuilderBase &builder,
1852 LLVM::ModuleTranslation &moduleTranslation,
1853 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1854 ArrayRef<llvm::Value *> llvmPrivateVars,
1855 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1856 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1857 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1858 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1859 // map copyRegion rhs arg
1860 llvm::Value *moldVar = findAssociatedValue(
1861 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1862 assert(moldVar);
1863 return moldVar;
1864 });
1865 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1866 llvmPrivateVars, privateDecls, insertBarrier,
1867 mappedPrivateVars);
1868}
1869
1870template <typename T>
1871static LogicalResult
1872cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
1873 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1874 PrivateVarsInfo &privateVarsInfo) {
1875 // private variable deallocation
1876 SmallVector<Region *> privateCleanupRegions;
1877 llvm::transform(privateVarsInfo.privatizers,
1878 std::back_inserter(privateCleanupRegions),
1879 [](omp::PrivateClauseOp privatizer) {
1880 return &privatizer.getDeallocRegion();
1881 });
1882
1883 if (failed(inlineOmpRegionCleanup(privateCleanupRegions,
1884 privateVarsInfo.llvmVars, moduleTranslation,
1885 builder, "omp.private.dealloc",
1886 /*shouldLoadCleanupRegionArg=*/false)))
1887 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1888 "`omp.private` op in");
1889
1890 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1891 bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1892 for (auto [privDecl, llvmPrivVar, blockArg] :
1893 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
1894 privateVarsInfo.blockArgs)) {
1895 if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
1896 ompBuilder->createOMPFreeShared(
1897 builder, llvmPrivVar,
1898 moduleTranslation.convertType(privDecl.getType()));
1899 }
1900 }
1901
1902 return success();
1903}
1904
/// Returns true if the construct contains omp.cancel or omp.cancellation_point
/// anywhere in its region. Used to decide whether cancellation finalization
/// callbacks need to be registered for the construct.
  // omp.cancel and omp.cancellation_point must be "closely nested" so they will
  // be visible and not inside of function calls. This is enforced by the
  // verifier.
  return op
      ->walk([](Operation *child) {
        if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}
1918
/// Converts an OpenMP sections construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocInsertPoints(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

      sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
      reductionDecls, privateReductionVariables, reductionVariableMap,
      isByRef)))
    return failure();


  // Wrap each omp.section region into a body-generation callback that is
  // later handed to createSections.
  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP,
                         ArrayRef<llvm::BasicBlock *> deallocBlocks) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocInsertPoints(builder, moduleTranslation);
  bool isCancellable = constructIsCancellable(sectionsOp);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
}
2025
/// Converts an OpenMP scope construct into LLVM IR.
static LogicalResult
convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (failed(checkImplementationStatus(*scopeOp)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(scopeOp.getReductionByref());
  assert(isByRef.size() == scopeOp.getNumReductionVars());

  PrivateVarsInfo privateVarsInfo(scopeOp);

  collectReductionDecls(scopeOp, reductionDecls);
  InsertPointTy allocaIP = findAllocInsertPoints(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      scopeOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();

  // Allocate private vars before the scope body
      scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (failed(handleError(afterAllocas, *scopeOp)))
    return failure();

      scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
      reductionDecls, privateReductionVariables, reductionVariableMap,
      isByRef)))
    return failure();

  // The body callback initializes the private copies (including firstprivate
  // copy regions) and then lowers the scope region itself.
  auto bodyCB =
      [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
          llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
    builder.restoreIP(codeGenIP);

    if (handleError(
            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            *scopeOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            scopeOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            scopeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    return convertOmpOpRegions(scopeOp.getRegion(), "omp.scope.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // The finalization callback runs the privatizers' dealloc regions at the
  // end of the construct.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);
    if (failed(cleanupPrivateVars(scopeOp, builder, moduleTranslation,
                                  scopeOp.getLoc(), privateVarsInfo)))
      return llvm::make_error<PreviouslyReportedError>();
    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());

  if (failed(handleError(afterIP, *scopeOp)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
      scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, scopeOp.getNowait(),
      /*isTeamsReduction=*/false);
}
2111
/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  if (failed(checkImplementationStatus(*singleOp)))
    return failure();

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                    llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  // Collect the translated copyprivate variables together with their
  // associated copy functions (looked up via their symbol references).
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
          llvmCPFuncs);

  if (failed(handleError(afterIP, *singleOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2155
/// Return the single omp.distribute op nested in \p teamsOp that captures all
/// non-debug uses of the teams reduction block arguments, or a null op if no
/// such distribute op exists (in which case teams-level reduction is used).
static omp::DistributeOp
  // Early return if we found more than one distribute op or if we can't find
  // any distribute op in the teams region.
  omp::DistributeOp distOp;
  WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
    if (distOp)
      return WalkResult::interrupt();
    distOp = op;
    return WalkResult::skip();
  });
  if (walk.wasInterrupted() || !distOp)
    return {};

  auto iface =
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  // Check that all uses of the reduction block arg has the same distribute op
  // parent.
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();
      // Ignore debug uses.
      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);
        continue;
      }
      if (!distOp->isProperAncestor(useOp))
        return {};
    }

  // If we are going to use distribute reduction then remove any debug uses of
  // the reduction parameters in teamsOp. Otherwise they will be left without
  // any mapped value in moduleTranslation and will eventually error out.
  for (auto *use : debugUses)
    use->erase();
  return distOp;
}
2194
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*op)))
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = op.getNumReductionVars();
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocInsertPoints(builder, moduleTranslation);

  // Only do teams reduction if there is no distribute op that captures the
  // reduction instead.
  bool doTeamsReduction = !getDistributeCapturingTeamsReduction(op);
  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

    collectReductionDecls(op, reductionDecls);

        op, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                    llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
        moduleTranslation, allocaIP, deallocBlocks);
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the optional clause operands (num_teams bounds, thread_limit,
  // if-expression) when present on the op.
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (!op.getNumTeamsUpperVars().empty())
    numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));

  llvm::Value *threadLimit = nullptr;
  if (!op.getThreadLimitVars().empty())
    threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
    // Process the reductions if required.
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
2275
2276static llvm::omp::RTLDependenceKindTy
2277convertDependKind(mlir::omp::ClauseTaskDepend kind) {
2278 switch (kind) {
2279 case mlir::omp::ClauseTaskDepend::taskdependin:
2280 return llvm::omp::RTLDependenceKindTy::DepIn;
2281 // The OpenMP runtime requires that the codegen for 'depend' clause for
2282 // 'out' dependency kind must be the same as codegen for 'depend' clause
2283 // with 'inout' dependency.
2284 case mlir::omp::ClauseTaskDepend::taskdependout:
2285 case mlir::omp::ClauseTaskDepend::taskdependinout:
2286 return llvm::omp::RTLDependenceKindTy::DepInOut;
2287 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2288 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2289 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2290 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2291 }
2292 llvm_unreachable("unhandled depend kind");
2293}
2294
/// Translate the entries of a `depend` clause (\p dependVars with matching
/// \p dependKinds) into OpenMPIRBuilder DependData records.
/// NOTE(review): dependKinds is dereferenced unconditionally below;
/// presumably the verifier guarantees it accompanies non-empty dependVars --
/// confirm.
    std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
    LLVM::ModuleTranslation &moduleTranslation,
  if (dependVars.empty())
    return;
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    auto kind =
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
    llvm::omp::RTLDependenceKindTy type = convertDependKind(kind);
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
  }
}
2310
/// Shared implementation of a callback which adds a terminator for the new
/// block created for the branch taken when an openmp construct is cancelled.
/// The terminator is saved in \p cancelTerminators. This callback is invoked
/// only if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct.
    SmallVectorImpl<llvm::UncondBrInst *> &cancelTerminators,
    llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
    mlir::Operation *op, llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    // ip is currently in the block branched to if cancellation occurred.
    // We need to create a branch to terminate that block.
    llvmBuilder.restoreIP(ip);

    // We must still clean up the construct after cancelling it, so we need to
    // branch to the block that finalizes the taskgroup.
    // That block has not been created yet so use this block as a dummy for now
    // and fix this after creating the operation.
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  };
  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
  // created in case the body contains omp.cancel (which will then expect to be
  // able to find this cleanup callback).
  ompBuilder.pushFinalizationCB(
      {finiCB, cancelDirective, constructIsCancellable(op)});
}
2341
/// If we cancelled the construct, we should branch to the finalization block of
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
/// is immediately before the continuation block. Now this finalization has
/// been created we can fix the branch.
static void
                        llvm::OpenMPIRBuilder &ompBuilder,
                        const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  // The cleanup block is the single predecessor of the continuation block, so
  // retarget the dummy cancel branches created earlier to it.
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
    cancelBranch->setSuccessor(constructFini);
}
2355
2356namespace {
2357/// TaskContextStructManager takes care of creating and freeing a structure
2358/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs created by createGEPsToPrivateVars() (empty until it is called).
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Null until generateTaskContextStruct() has been called.
  llvm::Value *getStructPtr() { return structPtr; }

private:
  // Non-owning references to state owned by the surrounding conversion.
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
2412
2413/// IteratorInfo extracts and prepares loop bounds information from an
2414/// mlir::omp::IteratorOp for lowering to LLVM IR.
2415///
2416/// It computes the per-dimension trip counts and the total linearized trip
2417/// count, casted to i64. These are used to build a canonical loop and to
2418/// reconstruct the physical induction variables inside the loop body.
2419class IteratorInfo {
2420private:
2421 llvm::SmallVector<llvm::Value *> lowerBounds;
2422 llvm::SmallVector<llvm::Value *> upperBounds;
2423 llvm::SmallVector<llvm::Value *> steps;
2424 llvm::SmallVector<llvm::Value *> trips;
2425 unsigned dims;
2426 llvm::Value *totalTrips;
2427
2428 llvm::Value *lookUpAsI64(mlir::Value val, const LLVM::ModuleTranslation &mt,
2429 llvm::IRBuilderBase &builder) {
2430 llvm::Value *v = mt.lookupValue(val);
2431 if (!v)
2432 return nullptr;
2433 if (v->getType()->isIntegerTy(64))
2434 return v;
2435 if (v->getType()->isIntegerTy())
2436 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2437 return nullptr;
2438 }
2439
2440public:
2441 IteratorInfo(mlir::omp::IteratorOp itersOp,
2442 mlir::LLVM::ModuleTranslation &moduleTranslation,
2443 llvm::IRBuilderBase &builder) {
2444 dims = itersOp.getLoopLowerBounds().size();
2445 lowerBounds.resize(dims);
2446 upperBounds.resize(dims);
2447 steps.resize(dims);
2448 trips.resize(dims);
2449
2450 for (unsigned d = 0; d < dims; ++d) {
2451 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2452 moduleTranslation, builder);
2453 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2454 moduleTranslation, builder);
2455 llvm::Value *st =
2456 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2457 assert(lb && ub && st &&
2458 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2459 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2460 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2461 "Expect non-zero step in IteratorOp");
2462
2463 lowerBounds[d] = lb;
2464 upperBounds[d] = ub;
2465 steps[d] = st;
2466
2467 // trips = ((ub - lb) / step) + 1 (inclusive ub, assume positive step)
2468 llvm::Value *diff = builder.CreateSub(ub, lb);
2469 llvm::Value *div = builder.CreateSDiv(diff, st);
2470 trips[d] = builder.CreateAdd(
2471 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2472 }
2473
2474 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2475 for (unsigned d = 0; d < dims; ++d)
2476 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2477 }
2478
2479 unsigned getDims() const { return dims; }
2480 llvm::ArrayRef<llvm::Value *> getLowerBounds() const { return lowerBounds; }
2481 llvm::ArrayRef<llvm::Value *> getUpperBounds() const { return upperBounds; }
2482 llvm::ArrayRef<llvm::Value *> getSteps() const { return steps; }
2483 llvm::ArrayRef<llvm::Value *> getTrips() const { return trips; }
2484 llvm::Value *getTotalTrips() const { return totalTrips; }
2485};
2486
2487} // namespace
2488
2489void TaskContextStructManager::generateTaskContextStruct() {
2490 if (privateDecls.empty())
2491 return;
2492 privateVarTypes.reserve(privateDecls.size());
2493
2494 for (omp::PrivateClauseOp &privOp : privateDecls) {
2495 // Skip private variables which can safely be allocated and initialised
2496 // inside of the task
2497 if (!privOp.readsFromMold())
2498 continue;
2499 Type mlirType = privOp.getType();
2500 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2501 }
2502
2503 if (privateVarTypes.empty())
2504 return;
2505
2506 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2507 privateVarTypes);
2508
2509 llvm::DataLayout dataLayout =
2510 builder.GetInsertBlock()->getModule()->getDataLayout();
2511 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2512 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2513
2514 // Heap allocate the structure
2515 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2516 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2517 "omp.task.context_ptr");
2518}
2519
2520SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2521 llvm::Value *altStructPtr) const {
2522 SmallVector<llvm::Value *> ret;
2523
2524 // Create GEPs for each struct member
2525 ret.reserve(privateDecls.size());
2526 llvm::Value *zero = builder.getInt32(0);
2527 unsigned i = 0;
2528 for (auto privDecl : privateDecls) {
2529 if (!privDecl.readsFromMold()) {
2530 // Handle this inside of the task so we don't pass unnessecary vars in
2531 ret.push_back(nullptr);
2532 continue;
2533 }
2534 llvm::Value *iVal = builder.getInt32(i);
2535 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2536 ret.push_back(gep);
2537 i += 1;
2538 }
2539 return ret;
2540}
2541
2542void TaskContextStructManager::createGEPsToPrivateVars() {
2543 if (!structPtr)
2544 assert(privateVarTypes.empty());
2545 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2546 // with null values for skipped private decls
2547
2548 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2549}
2550
2551void TaskContextStructManager::freeStructPtr() {
2552 if (!structPtr)
2553 return;
2554
2555 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2556 // Ensure we don't put the call to free() after the terminator
2557 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2558 builder.CreateFree(structPtr);
2559}
2560
2561static void storeAffinityEntry(llvm::IRBuilderBase &builder,
2562 llvm::OpenMPIRBuilder &ompBuilder,
2563 llvm::Value *affinityList, llvm::Value *index,
2564 llvm::Value *addr, llvm::Value *len) {
2565 llvm::StructType *kmpTaskAffinityInfoTy =
2566 ompBuilder.getKmpTaskAffinityInfoTy();
2567 llvm::Value *entry = builder.CreateInBoundsGEP(
2568 kmpTaskAffinityInfoTy, affinityList, index, "omp.affinity.entry");
2569
2570 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2571 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2572 /*isSigned=*/false);
2573 llvm::Value *flags = builder.getInt32(0);
2574
2575 builder.CreateStore(addr,
2576 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2577 builder.CreateStore(len,
2578 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2579 builder.CreateStore(flags,
2580 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2581}
2582
                                 llvm::IRBuilderBase &builder,
                                 LLVM::ModuleTranslation &moduleTranslation,
                                 llvm::Value *affinityList) {
  // Translate each affinity locator into one kmp_task_affinity_info entry at
  // its position in affinityList.
  for (auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
    auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
    assert(entryOp && "affinity item must be omp.affinity_entry");

    // Look up the already-translated LLVM values for the entry's address and
    // length operands.
    llvm::Value *addr = moduleTranslation.lookupValue(entryOp.getAddr());
    llvm::Value *len = moduleTranslation.lookupValue(entryOp.getLen());
    assert(addr && len && "expect affinity addr and len to be non-null");
    storeAffinityEntry(builder, *moduleTranslation.getOpenMPBuilder(),
                       affinityList, builder.getInt64(i), addr, len);
  }
}
2598
2599static mlir::LogicalResult
2600convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo,
2601 mlir::Block &iteratorRegionBlock,
2602 llvm::IRBuilderBase &builder,
2603 LLVM::ModuleTranslation &moduleTranslation) {
2604 llvm::Value *tmp = linearIV;
2605 for (int d = (int)iterInfo.getDims() - 1; d >= 0; --d) {
2606 llvm::Value *trip = iterInfo.getTrips()[d];
2607 // idx_d = tmp % trip_d
2608 llvm::Value *idx = builder.CreateURem(tmp, trip);
2609 // tmp = tmp / trip_d
2610 tmp = builder.CreateUDiv(tmp, trip);
2611
2612 // physIV_d = lb_d + idx_d * step_d
2613 llvm::Value *physIV = builder.CreateAdd(
2614 iterInfo.getLowerBounds()[d],
2615 builder.CreateMul(idx, iterInfo.getSteps()[d]), "omp.it.phys_iv");
2616
2617 moduleTranslation.mapValue(iteratorRegionBlock.getArgument(d), physIV);
2618 }
2619
2620 // Translate the iterator region into the loop body.
2621 moduleTranslation.mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2622 if (mlir::failed(moduleTranslation.convertBlock(iteratorRegionBlock,
2623 /*ignoreArguments=*/true,
2624 builder))) {
2625 return mlir::failure();
2626 }
2627 return mlir::success();
2628}
2629
2631 llvm::function_ref<void(llvm::Value *linearIV, mlir::omp::YieldOp yield)>;
2632
2633static mlir::LogicalResult
2634fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder,
2635 mlir::LLVM::ModuleTranslation &moduleTranslation,
2636 IteratorInfo &iterInfo, llvm::StringRef loopName,
2637 IteratorStoreEntryTy genStoreEntry) {
2638 mlir::Region &itersRegion = itersOp.getRegion();
2639 mlir::Block &iteratorRegionBlock = itersRegion.front();
2640
2641 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2642
2643 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2644 llvm::Value *linearIV) -> llvm::Error {
2645 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2646 builder.restoreIP(bodyIP);
2647
2648 if (failed(convertIteratorRegion(linearIV, iterInfo, iteratorRegionBlock,
2649 builder, moduleTranslation))) {
2650 return llvm::make_error<llvm::StringError>(
2651 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2652 }
2653
2654 auto yield =
2655 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.getTerminator());
2656 assert(yield && yield.getResults().size() == 1 &&
2657 "expect omp.yield in iterator region to have one result");
2658
2659 genStoreEntry(linearIV, yield);
2660
2661 // Iterator-region block/value mappings are temporary for this conversion,
2662 // clear them to avoid stale entries in ModuleTranslation.
2663 moduleTranslation.forgetMapping(itersRegion);
2664
2665 return llvm::Error::success();
2666 };
2667
2668 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2669 moduleTranslation.getOpenMPBuilder()->createIteratorLoop(
2670 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2671 if (failed(handleError(afterIP, *itersOp)))
2672 return failure();
2673
2674 builder.restoreIP(*afterIP);
2675
2676 return mlir::success();
2677}
2678
/// Build the llvm::OpenMPIRBuilder::AffinityData (an entry count plus a
/// pointer to a list of kmp_task_affinity_info entries) for a task's affinity
/// clause, covering both plain locators and iterator-generated entries.
static mlir::LogicalResult
buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder,
                  mlir::LLVM::ModuleTranslation &moduleTranslation,
                  llvm::OpenMPIRBuilder::AffinityData &ad) {

  // No affinity clause at all: report "no entries" to the caller.
  if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
    ad.Count = nullptr;
    ad.Info = nullptr;
    return mlir::success();
  }

  llvm::StructType *kmpTaskAffinityInfoTy =
      moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();

  // Stack-allocate an affinity list of `count` entries. When the count is a
  // constant or function argument the alloca can go in the usual alloca
  // insertion block; a count computed by earlier instructions must be
  // allocated at the current insertion point instead.
  auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
      builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
    return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
                                "omp.affinity_list");
  };

  // Package a (count, list) pair into an AffinityData: an i32 count and the
  // list pointer cast to the default address space.
  auto createAffinity =
      [&](llvm::Value *count,
          llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
    llvm::OpenMPIRBuilder::AffinityData ad{};
    ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
    ad.Info =
        builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
    return ad;
  };

  // Non-iterated affinity locators: a single list sized by the operand count.
  if (!taskOp.getAffinityVars().empty()) {
    llvm::Value *count = llvm::ConstantInt::get(
        builder.getInt64Ty(), taskOp.getAffinityVars().size());
    llvm::Value *list = allocateAffinityList(count);
    fillAffinityLocators(taskOp.getAffinityVars(), builder, moduleTranslation,
                         list);
    ads.emplace_back(createAffinity(count, list));
  }

  // Iterator modifiers: one list per omp.iterator, filled by an emitted loop
  // that stores one entry per logical iteration.
  if (!taskOp.getIterated().empty()) {
    for (auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
      auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
      assert(itersOp && "iterated value must be defined by omp.iterator");
      IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
      llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
      if (failed(fillIteratorLoop(
              itersOp, builder, moduleTranslation, iterInfo, "iterator",
              [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
                auto entryOp = yield.getResults()[0]
                                   .getDefiningOp<mlir::omp::AffinityEntryOp>();
                assert(entryOp && "expect yield produce an affinity entry");
                llvm::Value *addr =
                    moduleTranslation.lookupValue(entryOp.getAddr());
                llvm::Value *len =
                    moduleTranslation.lookupValue(entryOp.getLen());
                storeAffinityEntry(builder,
                                   *moduleTranslation.getOpenMPBuilder(),
                                   affList, linearIV, addr, len);
              })))
        return llvm::failure();
      ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
    }
  }

  // Total entry count across all partial lists.
  llvm::Value *totalAffinityCount = builder.getInt32(0);
  for (const auto &affinity : ads)
    totalAffinityCount = builder.CreateAdd(
        totalAffinityCount,
        builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
                              /*isSigned=*/false));

  // If more than one partial list was built, memcpy them back-to-back into
  // one packed allocation so a single contiguous list can be handed on.
  llvm::Value *affinityInfo = ads.front().Info;
  if (ads.size() > 1) {
    llvm::StructType *kmpTaskAffinityInfoTy =
        moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
    llvm::Value *affinityInfoElemSize = builder.getInt64(
        moduleTranslation.getLLVMModule()->getDataLayout().getTypeAllocSize(
            kmpTaskAffinityInfoTy));

    llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
    llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
    for (const auto &affinity : ads) {
      llvm::Value *affinityCount = builder.CreateIntCast(
          affinity.Count, builder.getInt32Ty(), /*isSigned=*/false);
      llvm::Value *affinityCountInt64 = builder.CreateIntCast(
          affinityCount, builder.getInt64Ty(), /*isSigned=*/false);
      // Byte size of this partial list.
      llvm::Value *affinityInfoSize =
          builder.CreateMul(affinityCountInt64, affinityInfoElemSize);

      // Destination slot inside the packed list.
      llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
          packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
          /*isSigned=*/false);
      packedAffinityInfoIndex = builder.CreateInBoundsGEP(
          kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);

      builder.CreateMemCpy(
          packedAffinityInfoIndex, llvm::Align(1),
          builder.CreatePointerBitCastOrAddrSpaceCast(
              affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
                                                  ->getPointerAddressSpace())),
          llvm::Align(1), affinityInfoSize);

      packedAffinityInfoOffset =
          builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
    }

    affinityInfo = packedAffinityInfo;
  }

  ad.Count = totalAffinityCount;
  ad.Info = affinityInfo;

  return mlir::success();
}
2796
// Allocates a single kmp_dep_info array sized to hold both locator
// (non-iterated) and iterated entries, fills the locator entries first, then
// runs an iterator loop for each iterator modifier object.
static mlir::LogicalResult
buildDependData(OperandRange dependVars, std::optional<ArrayAttr> dependKinds,
                OperandRange dependIterated,
                std::optional<ArrayAttr> dependIteratedKinds,
                llvm::IRBuilderBase &builder,
                mlir::LLVM::ModuleTranslation &moduleTranslation,
                llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
  // Fast path: no iterator modifiers, only plain locator dependencies.
  if (dependIterated.empty()) {
    buildDependDataLocator(dependKinds, dependVars, moduleTranslation,
                           taskDeps.Deps);
    return mlir::success();
  }

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  llvm::Type *dependInfoTy = ompBuilder.DependInfo;
  unsigned numLocator = dependVars.size();

  // Compute total count: locator deps + sum of iterator trip counts.
  llvm::Value *totalCount =
      llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);

  for (auto iter : dependIterated) {
    auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
    assert(itersOp && "depend_iterated value must be defined by omp.iterator");
    iterInfos.emplace_back(itersOp, moduleTranslation, builder);
    totalCount =
        builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
  }

  // Heap-allocate the kmp_depend_info array so we don't risk
  // dynamic-sized alloca outside the entry block (e.g. inside loops).
  llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
  llvm::Value *depArray =
      builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
                           totalCount, /*MallocF=*/nullptr, ".dep.arr.addr");

  // Fill non-iterated entries at indices [0, numLocator).
  if (numLocator > 0) {
    buildDependDataLocator(dependKinds, dependVars, moduleTranslation, dds);
    for (auto [i, dd] : llvm::enumerate(dds)) {
      llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
      llvm::Value *entry =
          builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
      ompBuilder.emitTaskDependency(builder, entry, dd);
    }
  }

  // Fill iterated entries starting at index numLocator.
  llvm::Value *offset =
      llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
  for (auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
    // The dependence kind for this iterator comes from the parallel
    // dependIteratedKinds attribute array.
    auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
        dependIteratedKinds->getValue()[i]);
    llvm::omp::RTLDependenceKindTy rtlKind =
        convertDependKind(kindAttr.getValue());

    auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
    // Emit a loop storing one dependence entry per logical iteration of the
    // iterator, at indices [offset, offset + totalTrips).
    if (failed(fillIteratorLoop(
            itersOp, builder, moduleTranslation, iterInfo, "dep_iterator",
            [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
              llvm::Value *addr =
                  moduleTranslation.lookupValue(yield.getResults()[0]);
              llvm::Value *idx = builder.CreateAdd(offset, linearIV);
              llvm::Value *entry =
                  builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
              ompBuilder.emitTaskDependency(
                  builder, entry,
                  llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
                                                    addr});
            })))
      return mlir::failure();

    // Advance offset by the trip count of this iterator.
    offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
  }

  taskDeps.DepArray = depArray;
  taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
  return mlir::success();
}
2882
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*taskOp)))
    return failure();

  PrivateVarsInfo privateVarsInfo(taskOp);
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
                                         privateVarsInfo.privatizers};

  // Allocate and copy private variables before creating the task. This avoids
  // accessing invalid memory if (after this scope ends) the private variables
  // are initialized from host variables or if the variables are copied into
  // from host variables (firstprivate). The insertion point is just before
  // where the code for creating and scheduling the task will go. That puts this
  // code outside of the outlined task region, which is what we want because
  // this way the initialization and copy regions are executed immediately while
  // the host variable data are still live.
  InsertPointTy allocaIP =
      findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);

  // Not using splitBB() because that requires the current block to have a
  // terminator.
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.task.start",
      /*Parent=*/builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
  builder.SetInsertPoint(branchToTaskStartBlock);

  // Now do this again to make the initialization and copy blocks
  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.init");

  // Now the control flow graph should look like
  // starter_block:
  //   <---- where we started when convertOmpTaskOp was called
  //   br %omp.private.init
  // omp.private.init:
  //   br %omp.private.copy
  // omp.private.copy:
  //   br %omp.task.start
  // omp.task.start:
  //   <---- where we want the insertion point to be when we call createTask()

  // Save the alloca insertion point on ModuleTranslation stack for use in
  // nested regions.
      moduleTranslation, allocaIP, deallocBlocks);

  // Allocate and initialize private variables
  builder.SetInsertPoint(initBlock->getTerminator());

  // Create task variable structure
  taskStructMgr.generateTaskContextStruct();
  // GEPs so that we can initialize the variables. Don't use these GEPs inside
  // of the body otherwise it will be the GEP not the struct which is forwarded
  // to the outlined function. GEPs forwarded in this way are passed in a
  // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
  // which may not be executed until after the current stack frame goes out of
  // scope.
  taskStructMgr.createGEPsToPrivateVars();

  for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs,
                       taskStructMgr.getLLVMPrivateVarGEPs())) {
    // To be handled inside the task.
    if (!privDecl.readsFromMold())
      continue;
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

    llvm::Expected<llvm::Value *> privateVarOrErr =
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskOp.getOperation());


    // TODO: this is a bit of a hack for Fortran character boxes.
    // Character boxes are passed by value into the init region and then the
    // initialized character box is yielded by value. Here we need to store the
    // yielded value into the private allocation, and load the private
    // allocation to match the type expected by region block arguments.
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
      // Load it so we have the value pointed to by the GEP
      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);
    }
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));

    // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
    // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
    // stack allocated structure.
  }

  // firstprivate copy region
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
  if (failed(copyFirstPrivateVars(
          taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
          taskOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  // Materialize affinity clause data (if any) before scheduling the task.
  llvm::OpenMPIRBuilder::AffinityData ad;
  if (failed(buildAffinityData(taskOp, builder, moduleTranslation, ad)))
    return llvm::failure();

  // Set up for call to createTask()
  builder.SetInsertPoint(taskStartBlock);

  auto bodyCB =
      [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
          llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP, deallocBlocks);

    // translate the body of the task:
    builder.restoreIP(codegenIP);

    // Allocate and initialize, inside the task, the private variables that
    // were not placed in the task context structure.
    llvm::BasicBlock *privInitBlock = nullptr;
    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
             privateVarsInfo.mlirVars))) {
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      // This is handled before the task executes
      if (privDecl.readsFromMold())
        continue;

      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");

      llvm::Expected<llvm::Value *> privateVarOrError =
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    }

    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
      if (!llvmPrivVar) {
        assert(privateVarsInfo.llvmVars[i] &&
               "This is added in the loop above");
        continue;
      }
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    }

    // Find and map the addresses of each variable within the task context
    // structure
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
                         privateVarsInfo.privatizers)) {
      // This was handled above.
      if (!privateDecl.readsFromMold())
        continue;
      // Fix broken pass-by-value case for Fortran character boxes
      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
      }
      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
    }

    auto continuationBlockOrError = convertOmpOpRegions(
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    if (failed(cleanupPrivateVars(taskOp, builder, moduleTranslation,
                                  taskOp.getLoc(), privateVarsInfo)))
      return llvm::make_error<PreviouslyReportedError>();

    // Free heap allocated task context structure at the end of the task.
    taskStructMgr.freeStructPtr();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  SmallVector<llvm::UncondBrInst *> cancelTerminators;
  // The directive to match here is OMPD_taskgroup because it is the taskgroup
  // which is canceled. This is handled here because it is the task's cleanup
  // block which should be branched to.
  pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
                           llvm::omp::Directive::OMPD_taskgroup);

  // Lower depend clauses (including iterator-modified ones) into the
  // DependenciesInfo structure consumed by createTask().
  llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
  if (failed(buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
                             taskOp.getDependIterated(),
                             taskOp.getDependIteratedKinds(), builder,
                             moduleTranslation, dependencies)))
    return failure();

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTask(
          ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getFinal()),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dependencies, ad,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));

  if (failed(handleError(afterIP, *taskOp)))
    return failure();

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());

  builder.restoreIP(*afterIP);

  // buildDependData heap-allocates the dependence array when iterator
  // modifiers are present; free it once the task has been created.
  if (dependencies.DepArray)
    builder.CreateFree(dependencies.DepArray);

  return success();
}
3125
3126/// The correct entry point is convertOmpTaskloopContextOp. This gets called
3127/// whilst lowering the body of the taskloop context (i.e. the task function).
3128static LogicalResult
3129convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp,
3130 llvm::IRBuilderBase &builder,
3131 LLVM::ModuleTranslation &moduleTranslation) {
3132 mlir::Operation &opInst = *loopWrapperOp.getOperation();
3133 if (failed(checkImplementationStatus(opInst)))
3134 return failure();
3135
3136 // Recurse into the loop body.
3137 auto continuationBlockOrError = convertOmpOpRegions(
3138 loopWrapperOp.getRegion(), "omp.taskloop.wrapper.region", builder,
3139 moduleTranslation);
3140
3141 if (failed(handleError(continuationBlockOrError, opInst)))
3142 return failure();
3143
3144 builder.SetInsertPoint(continuationBlockOrError.get());
3145 return success();
3146}
3147
/// Look up the given value in the mapping, and if it's not there, translate its
/// defining operation at the current builder insertion point. Only pure,
/// regionless operations are supported because the same operation will later be
/// translated again when the taskloop body itself is lowered.
static llvm::Expected<llvm::Value *>
                           LLVM::ModuleTranslation &moduleTranslation,
                           llvm::IRBuilderBase &builder) {
  if (llvm::Value *mapped = moduleTranslation.lookupValue(value))
    return mapped;

  Operation *defOp = value.getDefiningOp();
  if (!defOp)
    return llvm::make_error<llvm::StringError>(
        "value is a block argument and is not mapped",
        llvm::inconvertibleErrorCode());
  if (defOp->getNumRegions() != 0 || !isPure(defOp))
    return llvm::make_error<llvm::StringError>(
        "unsupported op defining taskloop loop bound",
        llvm::inconvertibleErrorCode());

  // Recursively translate any unmapped operands first, remembering which
  // mappings we created so they can be removed again afterwards.
  SmallVector<Value> mappingsToRemove;
  mappingsToRemove.reserve(defOp->getNumOperands() + defOp->getNumResults());
  for (Value operand : defOp->getOperands()) {
    if (moduleTranslation.lookupValue(operand))
      continue;

    llvm::Expected<llvm::Value *> operandOrError =
        lookupOrTranslatePureValue(operand, moduleTranslation, builder);
    if (!operandOrError)
      return operandOrError.takeError();
    moduleTranslation.mapValue(operand, *operandOrError);
    mappingsToRemove.push_back(operand);
  }

  if (failed(moduleTranslation.convertOperation(*defOp, builder)))
    return llvm::make_error<llvm::StringError>(
        "failed to convert op defining taskloop loop bound",
        llvm::inconvertibleErrorCode());

  llvm::Value *result = moduleTranslation.lookupValue(value);
  assert(result && "expected conversion of loop bound op to produce a value");

  // Drop every mapping introduced by this speculative translation (operands
  // mapped above plus any results mapped by convertOperation) so the op can
  // be translated again when the taskloop body is lowered.
  for (Value resultValue : defOp->getResults()) {
    if (moduleTranslation.lookupValue(resultValue))
      mappingsToRemove.push_back(resultValue);
  }
  for (Value mappedValue : mappingsToRemove)
    moduleTranslation.forgetMapping(mappedValue);

  return result;
}
3200
3201static llvm::Error
3202computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
3203 LLVM::ModuleTranslation &moduleTranslation,
3204 llvm::Value *&lbVal, llvm::Value *&ubVal,
3205 llvm::Value *&stepVal) {
3206 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
3207 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
3208 Operation::operand_range steps = loopOp.getLoopSteps();
3209
3210 llvm::Expected<llvm::Value *> firstLbOrErr =
3211 lookupOrTranslatePureValue(lowerBounds[0], moduleTranslation, builder);
3212 if (!firstLbOrErr)
3213 return firstLbOrErr.takeError();
3214
3215 llvm::Type *boundType = (*firstLbOrErr)->getType();
3216 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3217 if (loopOp.getCollapseNumLoops() > 1) {
3218 // In cases where Collapse is used with Taskloop, the upper bound of the
3219 // iteration space needs to be recalculated to cater for the collapsed loop.
3220 // The Collapsed Loop UpperBound is the product of all collapsed
3221 // loop's tripcount.
3222 // The LowerBound for collapsed loops is always 1. When the loops are
3223 // collapsed, it will reset the bounds and introduce processing to ensure
3224 // the index's are presented as expected. As this happens after creating
3225 // Taskloop, these bounds need predicting. Example:
3226 // !$omp taskloop collapse(2)
3227 // do i = 1, 10
3228 // do j = 1, 5
3229 // ..
3230 // end do
3231 // end do
3232 // This loop above has a total of 50 iterations, so the lb will be 1, and
3233 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
3234 // that i and j are properly presented when used in the loop.
3235 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3237 i == 0 ? std::move(firstLbOrErr)
3238 : lookupOrTranslatePureValue(lowerBounds[i], moduleTranslation,
3239 builder);
3240 if (!lbOrErr)
3241 return lbOrErr.takeError();
3243 upperBounds[i], moduleTranslation, builder);
3244 if (!ubOrErr)
3245 return ubOrErr.takeError();
3247 lookupOrTranslatePureValue(steps[i], moduleTranslation, builder);
3248 if (!stepOrErr)
3249 return stepOrErr.takeError();
3250
3251 llvm::Value *loopLb = *lbOrErr;
3252 llvm::Value *loopUb = *ubOrErr;
3253 llvm::Value *loopStep = *stepOrErr;
3254 // In some cases, such as where the ub is less than the lb so the loop
3255 // steps down, the calculation for the loopTripCount is swapped. To ensure
3256 // the correct value is found, calculate both UB - LB and LB - UB then
3257 // select which value to use depending on how the loop has been
3258 // configured.
3259 llvm::Value *loopLbMinusOne = builder.CreateSub(
3260 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3261 llvm::Value *loopUbMinusOne = builder.CreateSub(
3262 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3263 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3264 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3265 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3266 llvm::Value *loopTripCount =
3267 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3268 loopTripCount = builder.CreateBinaryIntrinsic(
3269 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3270 // For loops that have a step value not equal to 1, we need to adjust the
3271 // trip count to ensure the correct number of iterations for the loop is
3272 // captured.
3273 llvm::Value *loopTripCountDivStep =
3274 builder.CreateSDiv(loopTripCount, loopStep);
3275 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3276 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3277 llvm::Value *loopTripCountRem =
3278 builder.CreateSRem(loopTripCount, loopStep);
3279 loopTripCountRem = builder.CreateBinaryIntrinsic(
3280 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3281 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3282 loopTripCountRem,
3283 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3284 0));
3285 loopTripCount =
3286 builder.CreateAdd(loopTripCountDivStep,
3287 builder.CreateZExtOrTrunc(
3288 needsRoundUp, loopTripCountDivStep->getType()));
3289 ubVal = builder.CreateMul(ubVal, loopTripCount);
3290 }
3291 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3292 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3293 } else {
3295 lookupOrTranslatePureValue(upperBounds[0], moduleTranslation, builder);
3296 if (!ubOrErr)
3297 return ubOrErr.takeError();
3299 lookupOrTranslatePureValue(steps[0], moduleTranslation, builder);
3300 if (!stepOrErr)
3301 return stepOrErr.takeError();
3302 lbVal = *firstLbOrErr;
3303 ubVal = *ubOrErr;
3304 stepVal = *stepOrErr;
3305 }
3306
3307 assert(lbVal != nullptr && "Expected value for lbVal");
3308 assert(ubVal != nullptr && "Expected value for ubVal");
3309 assert(stepVal != nullptr && "Expected value for stepVal");
3310 return llvm::Error::success();
3311}
3312
// NOTE(review): this doxygen rendering dropped several declaration lines
// (content numbers 3331, 3348, 3416, 3508, 3525, 3527, 3546) — presumably the
// dealloc-block vector, the alloca SaveStack frames, the taskDupCB trailing
// return type, and the srcGEPs/destGEPs declarations. Verify against the
// upstream source before compiling; the code below is kept exactly as rendered.
3313// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
3314static LogicalResult
3315convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
3316                            llvm::IRBuilderBase &builder,
3317                            LLVM::ModuleTranslation &moduleTranslation) {
3318  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3319  mlir::Operation &opInst = *contextOp.getOperation();
3320  omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
  // Emit a clean "not yet implemented" diagnostic for unsupported clauses.
3321  if (failed(checkImplementationStatus(opInst)))
3322    return failure();
3323
3324  // It stores the pointer of allocated firstprivate copies,
3325  // which can be used later for freeing the allocated space.
3326  SmallVector<llvm::Value *> llvmFirstPrivateVars;
3327  PrivateVarsInfo privateVarsInfo(contextOp);
  // Owns the task context struct that carries (first)private copies into the
  // outlined task function (and across task duplication).
3328  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3329                                         privateVarsInfo.privatizers};
3330
3332  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3333      findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3334
  // Carve out dedicated blocks: createTaskloop() is emitted into
  // taskloopStartBlock, while the copy/init blocks below run beforehand.
3335  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3336  llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3337      builder.getContext(), "omp.taskloop.wrapper.start",
3338      /*Parent=*/builder.GetInsertBlock()->getParent());
3339  llvm::Instruction *branchToTaskloopStartBlock =
3340      builder.CreateBr(taskloopStartBlock);
3341  builder.SetInsertPoint(branchToTaskloopStartBlock);
3342
3343  llvm::BasicBlock *copyBlock =
3344      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
3345  llvm::BasicBlock *initBlock =
3346      splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
3347
3349      moduleTranslation, allocaIP, deallocBlocks);
3350
3351  // Allocate and initialize private variables
3352  builder.SetInsertPoint(initBlock->getTerminator());
3353
3354  // TODO: don't allocate if the loop has zero iterations.
3355  taskStructMgr.generateTaskContextStruct();
3356  taskStructMgr.createGEPsToPrivateVars();
3357
3358  llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
3359
  // Initialize, up front (before the task is created), those private copies
  // whose privatizer reads from the original variable ("mold"); the rest are
  // handled inside bodyCB below.
3360  for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3361           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
3362           privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3363    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3364    // To be handled inside the taskloop.
3365    if (!privDecl.readsFromMold())
3366      continue;
3367    assert(llvmPrivateVarAlloc &&
3368           "reads from mold so shouldn't have been skipped");
3369
3370    llvm::Expected<llvm::Value *> privateVarOrErr =
3371        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3372                       blockArg, llvmPrivateVarAlloc, initBlock);
3373    if (!privateVarOrErr)
3374      return handleError(privateVarOrErr, *contextOp.getOperation());
3375
3376    llvmFirstPrivateVars[i] = privateVarOrErr.get();
3377
3378    llvm::IRBuilderBase::InsertPointGuard guard(builder);
3379    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3380
    // Pass-by-value case (e.g. Fortran character boxes): store the yielded
    // value into the struct slot, then reload to match the block-arg type.
3381    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3382        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3383      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3384      // Load it so we have the value pointed to by the GEP
3385      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3386                                               llvmPrivateVarAlloc);
3387    }
3388    assert(llvmPrivateVarAlloc->getType() ==
3389           moduleTranslation.convertType(blockArg.getType()));
3390  }
3391
3392  // firstprivate copy region
3393  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3394  if (failed(copyFirstPrivateVars(
3395          contextOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3396          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3397          contextOp.getPrivateNeedsBarrier())))
3398    return llvm::failure();
3399
3400  // Set up inserttion point for call to createTaskloop()
3401  builder.SetInsertPoint(taskloopStartBlock);
3402
  // Compute normalized loop bounds for createTaskloop(); guaranteed non-null
  // on success (see computeTaskloopBounds).
3403  auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3404  llvm::Value *lbVal = nullptr;
3405  llvm::Value *ubVal = nullptr;
3406  llvm::Value *stepVal = nullptr;
3407  if (llvm::Error err = computeTaskloopBounds(
3408          loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3409    return handleError(std::move(err), opInst);
3410
  // bodyCB lowers the task body: allocate/init privates that do NOT read from
  // the mold, remap block args onto the context-struct GEPs, lower the
  // context region, then clean up the per-task copies.
3411  auto bodyCB =
3412      [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3413          llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3414    // Save the alloca insertion point on ModuleTranslation stack for use in
3415    // nested regions.
3417        moduleTranslation, allocaIP, deallocBlocks);
3418
3419    // translate the body of the taskloop:
3420    builder.restoreIP(codegenIP);
3421
3422    llvm::BasicBlock *privInitBlock = nullptr;
3423    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3424    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3425             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3426             privateVarsInfo.mlirVars))) {
3427      auto [blockArg, privDecl, mlirPrivVar] = zip;
3428      // This is handled before the task executes
3429      if (privDecl.readsFromMold())
3430        continue;
3431
3432      llvm::IRBuilderBase::InsertPointGuard guard(builder);
3433      llvm::Type *llvmAllocType =
3434          moduleTranslation.convertType(privDecl.getType());
3435      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3436      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3437          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3438
3439      llvm::Expected<llvm::Value *> privateVarOrError =
3440          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3441                         blockArg, llvmPrivateVar, privInitBlock);
3442      if (!privateVarOrError)
3443        return privateVarOrError.takeError();
3444      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3445      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3446    }
3447
    // Recompute GEPs relative to the struct pointer seen inside the task and
    // prefer them over the pre-task values where available.
3448    taskStructMgr.createGEPsToPrivateVars();
3449    for (auto [i, llvmPrivVar] :
3450         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3451      if (!llvmPrivVar) {
3452        assert(privateVarsInfo.llvmVars[i] &&
3453               "This is added in the loop above");
3454        continue;
3455      }
3456      privateVarsInfo.llvmVars[i] = llvmPrivVar;
3457    }
3458
3459    // Find and map the addresses of each variable within the taskloop context
3460    // structure
3461    for (auto [blockArg, llvmPrivateVar, privateDecl] :
3462         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3463                         privateVarsInfo.privatizers)) {
3464      // This was handled above.
3465      if (!privateDecl.readsFromMold())
3466        continue;
3467      // Fix broken pass-by-value case for Fortran character boxes
3468      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3469        llvmPrivateVar = builder.CreateLoad(
3470            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3471      }
3472      assert(llvmPrivateVar->getType() ==
3473             moduleTranslation.convertType(blockArg.getType()));
3474      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3475    }
3476
3477    // Lower the contents of the taskloop context region: this is the body of
3478    // the generated task, not the loop.
3479    auto continuationBlockOrError = convertOmpOpRegions(
3480        contextOp.getRegion(), "omp.taskloop.context.region", builder,
3481        moduleTranslation);
3482
3483    if (failed(handleError(continuationBlockOrError, opInst)))
3484      return llvm::make_error<PreviouslyReportedError>();
3485
3486    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3487
3488    // This is freeing the private variables as mapped inside of the task: these
3489    // will be per-task private copies possibly after task duplication. This is
3490    // handled transparently by how these are passed to the structure passed
3491    // into the outlined function. When the task is duplicated, that structure
3492    // is duplicated too.
3493    if (failed(cleanupPrivateVars(contextOp, builder, moduleTranslation,
3494                                  contextOp.getLoc(), privateVarsInfo)))
3495      return llvm::make_error<PreviouslyReportedError>();
3496    // Similarly, the task context structure freed inside the task is the
3497    // per-task copy after task duplication.
3498    taskStructMgr.freeStructPtr();
3499
3500    return llvm::Error::success();
3501  };
3502
3503  // Taskloop divides into an appropriate number of tasks by repeatedly
3504  // duplicating the original task. Each time this is done, the task context
3505  // structure must be duplicated too.
3506  auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3507                       llvm::Value *destPtr, llvm::Value *srcPtr)
3509    llvm::IRBuilderBase::InsertPointGuard guard(builder);
3510    builder.restoreIP(codegenIP);
3511
    // srcPtr/destPtr are pointers to the context-struct pointers of the task
    // being copied from/into; load the source struct and allocate a fresh
    // destination struct.
3512    llvm::Type *ptrTy =
3513        builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3514    llvm::Value *src =
3515        builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
3516
3517    TaskContextStructManager &srcStructMgr = taskStructMgr;
3518    TaskContextStructManager destStructMgr(builder, moduleTranslation,
3519                                           privateVarsInfo.privatizers);
3520    destStructMgr.generateTaskContextStruct();
3521    llvm::Value *dest = destStructMgr.getStructPtr();
3522    dest->setName("omp.taskloop.context.dest");
3523    builder.CreateStore(dest, destPtr);
3524
3526        srcStructMgr.createGEPsToPrivateVars(src);
3528        destStructMgr.createGEPsToPrivateVars(dest);
3529
3530    // Inline init regions.
3531    for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3532         llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
3533                         privateVarsInfo.blockArgs, destGEPs)) {
3534      // To be handled inside task body.
3535      if (!privDecl.readsFromMold())
3536        continue;
3537      assert(llvmPrivateVarAlloc &&
3538             "reads from mold so shouldn't have been skipped");
3539
      // Re-run the init region with the SOURCE task's copy as the mold.
3540      llvm::Expected<llvm::Value *> privateVarOrErr =
3541          initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3542                         llvmPrivateVarAlloc, builder.GetInsertBlock());
3543      if (!privateVarOrErr)
3544        return privateVarOrErr.takeError();
3545
3547
3548      // TODO: this is a bit of a hack for Fortran character boxes.
3549      // Character boxes are passed by value into the init region and then the
3550      // initialized character box is yielded by value. Here we need to store
3551      // the yielded value into the private allocation, and load the private
3552      // allocation to match the type expected by region block arguments.
3553      if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3554          !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3555        builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3556        // Load it so we have the value pointed to by the GEP
3557        llvmPrivateVarAlloc = builder.CreateLoad(
3558            privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3559      }
3560      assert(llvmPrivateVarAlloc->getType() ==
3561             moduleTranslation.convertType(blockArg.getType()));
3562
3563      // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
3564      // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
3565      // through a stack allocated structure.
3566    }
3567
3568    if (failed(copyFirstPrivateVars(contextOp.getOperation(), builder,
3569                                    moduleTranslation, srcGEPs, destGEPs,
3570                                    privateVarsInfo.privatizers,
3571                                    contextOp.getPrivateNeedsBarrier())))
3572      return llvm::make_error<PreviouslyReportedError>();
3573
3574    return builder.saveIP();
3575  };
3576
  // createTaskloop() queries the canonical loop info that was recorded while
  // lowering the wrapped loop nest.
3577  auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
3578    llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3579    return loopInfo;
3580  };
3581
  // Translate if/grainsize/num_tasks clauses. `sched` encodes which clause
  // was present: 0 = default, 1 = grainsize, 2 = num_tasks; `grainsize`
  // carries the value for either clause.
3582  llvm::Value *ifCond = nullptr;
3583  llvm::Value *grainsize = nullptr;
3584  int sched = 0; // default
3585  mlir::Value grainsizeVal = contextOp.getGrainsize();
3586  mlir::Value numTasksVal = contextOp.getNumTasks();
3587  if (Value ifVar = contextOp.getIfExpr())
3588    ifCond = moduleTranslation.lookupValue(ifVar);
3589  if (grainsizeVal) {
3590    grainsize = moduleTranslation.lookupValue(grainsizeVal);
3591    sched = 1; // grainsize
3592  } else if (numTasksVal) {
3593    grainsize = moduleTranslation.lookupValue(numTasksVal);
3594    sched = 2; // num_tasks
3595  }
3596
  // Only pass a duplication callback when there is a context struct to copy.
3597  llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
3598  if (taskStructMgr.getStructPtr())
3599    taskDupOrNull = taskDupCB;
3600
3601  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3602  SmallVector<llvm::UncondBrInst *> cancelTerminators;
3603  // The directive to match here is OMPD_taskgroup because it is the
3604  // taskgroup which is canceled. This is handled here because it is the
3605  // task's cleanup block which should be branched to. It doesn't depend upon
3606  // nogroup because even in that case the taskloop might still be inside an
3607  // explicit taskgroup.
3608  pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, contextOp,
3609                           llvm::omp::Directive::OMPD_taskgroup);
3610
3611  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3612  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3613      moduleTranslation.getOpenMPBuilder()->createTaskloop(
3614          ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3615          stepVal, contextOp.getUntied(), ifCond, grainsize,
3616          contextOp.getNogroup(), sched,
3617          moduleTranslation.lookupValue(contextOp.getFinal()),
3618          contextOp.getMergeable(),
3619          moduleTranslation.lookupValue(contextOp.getPriority()),
3620          loopOp.getCollapseNumLoops(), taskDupOrNull,
3621          taskStructMgr.getStructPtr());
3622
3623  if (failed(handleError(afterIP, opInst)))
3624    return failure();
3625
  // Retarget cancellation branches now that the cleanup block is known.
3626  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3627
3628  builder.restoreIP(*afterIP);
3629  return success();
3630}
3631
3632/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
3633static LogicalResult
3634convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
3635 LLVM::ModuleTranslation &moduleTranslation) {
3636 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3637 if (failed(checkImplementationStatus(*tgOp)))
3638 return failure();
3639
3640 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3641 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
3642 builder.restoreIP(codegenIP);
3643 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
3644 builder, moduleTranslation)
3645 .takeError();
3646 };
3647
3649 InsertPointTy allocaIP =
3650 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3651 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3652 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3653 moduleTranslation.getOpenMPBuilder()->createTaskgroup(
3654 ompLoc, allocaIP, deallocBlocks, bodyCB);
3655
3656 if (failed(handleError(afterIP, *tgOp)))
3657 return failure();
3658
3659 builder.restoreIP(*afterIP);
3660 return success();
3661}
3662
3663static LogicalResult
3664convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
3665 LLVM::ModuleTranslation &moduleTranslation) {
3666 if (failed(checkImplementationStatus(*twOp)))
3667 return failure();
3668
3669 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
3670 return success();
3671}
3672
3673/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
3674static LogicalResult
3675convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
3676 LLVM::ModuleTranslation &moduleTranslation) {
3677 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3678 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3679 if (failed(checkImplementationStatus(opInst)))
3680 return failure();
3681
3682 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3683 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
3684 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3685
3686 // Static is the default.
3687 auto schedule =
3688 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3689
3690 // Find the loop configuration.
3691 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
3692 llvm::Type *ivType = step->getType();
3693 llvm::Value *chunk = nullptr;
3694 if (wsloopOp.getScheduleChunk()) {
3695 llvm::Value *chunkVar =
3696 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
3697 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3698 }
3699
3700 omp::DistributeOp distributeOp = nullptr;
3701 llvm::Value *distScheduleChunk = nullptr;
3702 bool hasDistSchedule = false;
3703 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
3704 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
3705 hasDistSchedule = distributeOp.getDistScheduleStatic();
3706 if (distributeOp.getDistScheduleChunkSize()) {
3707 llvm::Value *chunkVar = moduleTranslation.lookupValue(
3708 distributeOp.getDistScheduleChunkSize());
3709 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3710 }
3711 }
3712
3713 PrivateVarsInfo privateVarsInfo(wsloopOp);
3714
3716 collectReductionDecls(wsloopOp, reductionDecls);
3717
3718 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3719 findAllocInsertPoints(builder, moduleTranslation);
3720
3721 SmallVector<llvm::Value *> privateReductionVariables(
3722 wsloopOp.getNumReductionVars());
3723
3725 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3726 if (handleError(afterAllocas, opInst).failed())
3727 return failure();
3728
3729 DenseMap<Value, llvm::Value *> reductionVariableMap;
3730
3731 MutableArrayRef<BlockArgument> reductionArgs =
3732 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3733
3734 SmallVector<DeferredStore> deferredStores;
3735
3736 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3737 moduleTranslation, allocaIP, reductionDecls,
3738 privateReductionVariables, reductionVariableMap,
3739 deferredStores, isByRef)))
3740 return failure();
3741
3742 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3743 opInst)
3744 .failed())
3745 return failure();
3746
3747 if (failed(copyFirstPrivateVars(
3748 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3749 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3750 wsloopOp.getPrivateNeedsBarrier())))
3751 return failure();
3752
3753 assert(afterAllocas.get()->getSinglePredecessor());
3754 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3755 moduleTranslation,
3756 afterAllocas.get()->getSinglePredecessor(),
3757 reductionDecls, privateReductionVariables,
3758 reductionVariableMap, isByRef, deferredStores)))
3759 return failure();
3760
3761 // TODO: Handle doacross loops when the ordered clause has a parameter.
3762 bool isOrdered = wsloopOp.getOrdered().has_value();
3763 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3764 bool isSimd = wsloopOp.getScheduleSimd();
3765 bool loopNeedsBarrier = !wsloopOp.getNowait();
3766
3767 // The only legal way for the direct parent to be omp.distribute is that this
3768 // represents 'distribute parallel do'. Otherwise, this is a regular
3769 // worksharing loop.
3770 llvm::omp::WorksharingLoopType workshareLoopType =
3771 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3772 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3773 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3774
3775 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3776 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3777 llvm::omp::Directive::OMPD_for);
3778
3779 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3780
3781 // Initialize linear variables and linear step
3782 LinearClauseProcessor linearClauseProcessor;
3783
3784 if (!wsloopOp.getLinearVars().empty()) {
3785 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3786 for (mlir::Attribute linearVarType : linearVarTypes)
3787 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3788
3789 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3790 linearClauseProcessor.createLinearVar(
3791 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3792 idx);
3793 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3794 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3795 }
3796
3798 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3799
3800 if (failed(handleError(regionBlock, opInst)))
3801 return failure();
3802
3803 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3804
3805 // Emit Initialization and Update IR for linear variables
3806 if (!wsloopOp.getLinearVars().empty()) {
3807 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3808 loopInfo->getPreheader());
3809 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3810 moduleTranslation.getOpenMPBuilder()->createBarrier(
3811 builder.saveIP(), llvm::omp::OMPD_barrier);
3812 if (failed(handleError(afterBarrierIP, *loopOp)))
3813 return failure();
3814 builder.restoreIP(*afterBarrierIP);
3815 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3816 loopInfo->getIndVar());
3817 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3818 }
3819
3820 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3821
3822 // Check if we can generate no-loop kernel
3823 bool noLoopMode = false;
3824 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3825 if (targetOp) {
3826 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3827 // We need this check because, without it, noLoopMode would be set to true
3828 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3829 // that loop is not the top-level SPMD one.
3830 if (loopOp == targetCapturedOp) {
3831 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3832 omp::TargetExecMode::no_loop)
3833 noLoopMode = true;
3834 }
3835 }
3836
3837 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3838 ompBuilder->applyWorkshareLoop(
3839 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3840 convertToScheduleKind(schedule), chunk, isSimd,
3841 scheduleMod == omp::ScheduleModifier::monotonic,
3842 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3843 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3844
3845 if (failed(handleError(wsloopIP, opInst)))
3846 return failure();
3847
3848 // Emit finalization and in-place rewrites for linear vars.
3849 if (!wsloopOp.getLinearVars().empty()) {
3850 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3851 assert(loopInfo->getLastIter() &&
3852 "`lastiter` in CanonicalLoopInfo is nullptr");
3853 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3854 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3855 loopInfo->getLastIter());
3856 if (failed(handleError(afterBarrierIP, *loopOp)))
3857 return failure();
3858 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3859 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3860 index);
3861 builder.restoreIP(oldIP);
3862 }
3863
3864 // Set the correct branch target for task cancellation
3865 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3866
3867 // Process the reductions if required.
3868 if (failed(createReductionsAndCleanup(
3869 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3870 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3871 /*isTeamsReduction=*/false)))
3872 return failure();
3873
3874 return cleanupPrivateVars(wsloopOp, builder, moduleTranslation,
3875 wsloopOp.getLoc(), privateVarsInfo);
3876}
3877
3878/// Converts the OpenMP parallel operation to LLVM IR.
3879static LogicalResult
3880convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3881 LLVM::ModuleTranslation &moduleTranslation) {
3882 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3883 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3884 assert(isByRef.size() == opInst.getNumReductionVars());
3885 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3886 bool isCancellable = constructIsCancellable(opInst);
3887
3888 if (failed(checkImplementationStatus(*opInst)))
3889 return failure();
3890
3891 PrivateVarsInfo privateVarsInfo(opInst);
3892
3893 // Collect reduction declarations
3895 collectReductionDecls(opInst, reductionDecls);
3896 SmallVector<llvm::Value *> privateReductionVariables(
3897 opInst.getNumReductionVars());
3898 SmallVector<DeferredStore> deferredStores;
3899
3900 auto bodyGenCB =
3901 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3902 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3904 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3905 if (handleError(afterAllocas, *opInst).failed())
3906 return llvm::make_error<PreviouslyReportedError>();
3907
3908 // Allocate reduction vars
3909 DenseMap<Value, llvm::Value *> reductionVariableMap;
3910
3911 MutableArrayRef<BlockArgument> reductionArgs =
3912 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3913
3914 allocaIP =
3915 InsertPointTy(allocaIP.getBlock(),
3916 allocaIP.getBlock()->getTerminator()->getIterator());
3917
3918 if (failed(allocReductionVars(
3919 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3920 reductionDecls, privateReductionVariables, reductionVariableMap,
3921 deferredStores, isByRef)))
3922 return llvm::make_error<PreviouslyReportedError>();
3923
3924 assert(afterAllocas.get()->getSinglePredecessor());
3925 builder.restoreIP(codeGenIP);
3926
3927 if (handleError(
3928 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3929 *opInst)
3930 .failed())
3931 return llvm::make_error<PreviouslyReportedError>();
3932
3933 if (failed(copyFirstPrivateVars(
3934 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3935 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3936 opInst.getPrivateNeedsBarrier())))
3937 return llvm::make_error<PreviouslyReportedError>();
3938
3939 if (failed(
3940 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3941 afterAllocas.get()->getSinglePredecessor(),
3942 reductionDecls, privateReductionVariables,
3943 reductionVariableMap, isByRef, deferredStores)))
3944 return llvm::make_error<PreviouslyReportedError>();
3945
3946 // Save the alloca insertion point on ModuleTranslation stack for use in
3947 // nested regions.
3949 moduleTranslation, allocaIP, deallocBlocks);
3950
3951 // ParallelOp has only one region associated with it.
3953 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3954 if (!regionBlock)
3955 return regionBlock.takeError();
3956
3957 // Process the reductions if required.
3958 if (opInst.getNumReductionVars() > 0) {
3959 // Collect reduction info
3961 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3963 owningReductionGenRefDataPtrGens;
3965 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3966 owningReductionGens, owningAtomicReductionGens,
3967 owningReductionGenRefDataPtrGens,
3968 privateReductionVariables, reductionInfos, isByRef);
3969
3970 // Move to region cont block
3971 builder.SetInsertPoint((*regionBlock)->getTerminator());
3972
3973 // Generate reductions from info
3974 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3975 builder.SetInsertPoint(tempTerminator);
3976
3977 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3978 ompBuilder->createReductions(
3979 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3980 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3981 if (!contInsertPoint)
3982 return contInsertPoint.takeError();
3983
3984 if (!contInsertPoint->getBlock())
3985 return llvm::make_error<PreviouslyReportedError>();
3986
3987 tempTerminator->eraseFromParent();
3988 builder.restoreIP(*contInsertPoint);
3989 }
3990
3991 return llvm::Error::success();
3992 };
3993
3994 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3995 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3996 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3997 // bodyGenCB.
3998 replVal = &val;
3999 return codeGenIP;
4000 };
4001
4002 // TODO: Perform finalization actions for variables. This has to be
4003 // called for variables which have destructors/finalizers.
4004 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4005 InsertPointTy oldIP = builder.saveIP();
4006 builder.restoreIP(codeGenIP);
4007
4008 // if the reduction has a cleanup region, inline it here to finalize the
4009 // reduction variables
4010 SmallVector<Region *> reductionCleanupRegions;
4011 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4012 [](omp::DeclareReductionOp reductionDecl) {
4013 return &reductionDecl.getCleanupRegion();
4014 });
4015 if (failed(inlineOmpRegionCleanup(
4016 reductionCleanupRegions, privateReductionVariables,
4017 moduleTranslation, builder, "omp.reduction.cleanup")))
4018 return llvm::createStringError(
4019 "failed to inline `cleanup` region of `omp.declare_reduction`");
4020
4021 if (failed(cleanupPrivateVars(opInst, builder, moduleTranslation,
4022 opInst.getLoc(), privateVarsInfo)))
4023 return llvm::make_error<PreviouslyReportedError>();
4024
4025 // If we could be performing cancellation, add the cancellation barrier on
4026 // the way out of the outlined region.
4027 if (isCancellable) {
4028 auto IPOrErr = ompBuilder->createBarrier(
4029 llvm::OpenMPIRBuilder::LocationDescription(builder),
4030 llvm::omp::Directive::OMPD_unknown,
4031 /* ForceSimpleCall */ false,
4032 /* CheckCancelFlag */ false);
4033 if (!IPOrErr)
4034 return IPOrErr.takeError();
4035 }
4036
4037 builder.restoreIP(oldIP);
4038 return llvm::Error::success();
4039 };
4040
4041 llvm::Value *ifCond = nullptr;
4042 if (auto ifVar = opInst.getIfExpr())
4043 ifCond = moduleTranslation.lookupValue(ifVar);
4044 llvm::Value *numThreads = nullptr;
4045 if (!opInst.getNumThreadsVars().empty())
4046 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
4047 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4048 if (auto bind = opInst.getProcBindKind())
4049 pbKind = getProcBindKind(*bind);
4050
4052 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4053 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
4054 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4055
4056 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4057 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4058 privCB, finiCB, ifCond, numThreads, pbKind,
4059 isCancellable);
4060
4061 if (failed(handleError(afterIP, *opInst)))
4062 return failure();
4063
4064 builder.restoreIP(*afterIP);
4065 return success();
4066}
4067
4068/// Convert Order attribute to llvm::omp::OrderKind.
4069static llvm::omp::OrderKind
4070convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
4071 if (!o)
4072 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4073 switch (*o) {
4074 case omp::ClauseOrderKind::Concurrent:
4075 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4076 }
4077 llvm_unreachable("Unknown ClauseOrderKind kind");
4078}
4079
4080/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
4081static LogicalResult
4082convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
4083 LLVM::ModuleTranslation &moduleTranslation) {
4084 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4085 auto simdOp = cast<omp::SimdOp>(opInst);
4086
4087 if (failed(checkImplementationStatus(opInst)))
4088 return failure();
4089
4090 PrivateVarsInfo privateVarsInfo(simdOp);
4091
4092 MutableArrayRef<BlockArgument> reductionArgs =
4093 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4094 DenseMap<Value, llvm::Value *> reductionVariableMap;
4095 SmallVector<llvm::Value *> privateReductionVariables(
4096 simdOp.getNumReductionVars());
4097 SmallVector<DeferredStore> deferredStores;
4099 collectReductionDecls(simdOp, reductionDecls);
4100 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
4101 assert(isByRef.size() == simdOp.getNumReductionVars());
4102
4103 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4104 findAllocInsertPoints(builder, moduleTranslation);
4105
4107 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
4108 if (handleError(afterAllocas, opInst).failed())
4109 return failure();
4110
4111 // Initialize linear variables and linear step
4112 LinearClauseProcessor linearClauseProcessor;
4113
4114 if (!simdOp.getLinearVars().empty()) {
4115 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4116 for (mlir::Attribute linearVarType : linearVarTypes)
4117 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4118 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4119 bool isImplicit = false;
4120 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4121 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
4122 // If the linear variable is implicit, reuse the already
4123 // existing llvm::Value
4124 if (linearVar == mlirPrivVar) {
4125 isImplicit = true;
4126 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4127 llvmPrivateVar, idx);
4128 break;
4129 }
4130 }
4131
4132 if (!isImplicit)
4133 linearClauseProcessor.createLinearVar(
4134 builder, moduleTranslation,
4135 moduleTranslation.lookupValue(linearVar), idx);
4136 }
4137 for (mlir::Value linearStep : simdOp.getLinearStepVars())
4138 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4139 }
4140
4141 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
4142 moduleTranslation, allocaIP, reductionDecls,
4143 privateReductionVariables, reductionVariableMap,
4144 deferredStores, isByRef)))
4145 return failure();
4146
4147 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
4148 opInst)
4149 .failed())
4150 return failure();
4151
4152 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
4153 // SIMD.
4154
4155 assert(afterAllocas.get()->getSinglePredecessor());
4156 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4157 moduleTranslation,
4158 afterAllocas.get()->getSinglePredecessor(),
4159 reductionDecls, privateReductionVariables,
4160 reductionVariableMap, isByRef, deferredStores)))
4161 return failure();
4162
4163 llvm::ConstantInt *simdlen = nullptr;
4164 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4165 simdlen = builder.getInt64(simdlenVar.value());
4166
4167 llvm::ConstantInt *safelen = nullptr;
4168 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4169 safelen = builder.getInt64(safelenVar.value());
4170
4171 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4172 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
4173
4174 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4175 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4176 mlir::OperandRange operands = simdOp.getAlignedVars();
4177 for (size_t i = 0; i < operands.size(); ++i) {
4178 llvm::Value *alignment = nullptr;
4179 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
4180 llvm::Type *ty = llvmVal->getType();
4181
4182 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4183 alignment = builder.getInt64(intAttr.getInt());
4184 assert(ty->isPointerTy() && "Invalid type for aligned variable");
4185 assert(alignment && "Invalid alignment value");
4186
4187 // Check if the alignment value is not a power of 2. If so, skip emitting
4188 // alignment.
4189 if (!intAttr.getValue().isPowerOf2())
4190 continue;
4191
4192 auto curInsert = builder.saveIP();
4193 builder.SetInsertPoint(sourceBlock);
4194 llvmVal = builder.CreateLoad(ty, llvmVal);
4195 builder.restoreIP(curInsert);
4196 alignedVars[llvmVal] = alignment;
4197 }
4198
4200 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
4201
4202 if (failed(handleError(regionBlock, opInst)))
4203 return failure();
4204
4205 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
4206 // Emit Initialization for linear variables
4207 if (simdOp.getLinearVars().size()) {
4208 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4209 loopInfo->getPreheader());
4210
4211 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4212 loopInfo->getIndVar());
4213 }
4214 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4215
4216 ompBuilder->applySimd(loopInfo, alignedVars,
4217 simdOp.getIfExpr()
4218 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
4219 : nullptr,
4220 order, simdlen, safelen);
4221
4222 linearClauseProcessor.emitStoresForLinearVar(builder);
4223
4224 // Check if this SIMD loop contains ordered regions
4225 bool hasOrderedRegions = false;
4226 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4227 hasOrderedRegions = true;
4228 return WalkResult::interrupt();
4229 });
4230
4231 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
4232 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
4233 index);
4234 if (hasOrderedRegions) {
4235 // Also rewrite uses in ordered regions so they read the current value
4236 linearClauseProcessor.rewriteInPlace(builder, "omp.ordered.region",
4237 index);
4238 // Also rewrite uses in finalize blocks (code after ordered regions)
4239 linearClauseProcessor.rewriteInPlace(builder, "omp_region.finalize",
4240 index);
4241 }
4242 }
4243
4244 // We now need to reduce the per-simd-lane reduction variable into the
4245 // original variable. This works a bit differently to other reductions (e.g.
4246 // wsloop) because we don't need to call into the OpenMP runtime to handle
4247 // threads: everything happened in this one thread.
4248 for (auto [i, tuple] : llvm::enumerate(
4249 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4250 privateReductionVariables))) {
4251 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4252
4253 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
4254 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
4255 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
4256
4257 // We have one less load for by-ref case because that load is now inside of
4258 // the reduction region.
4259 llvm::Value *redValue = originalVariable;
4260 if (!byRef)
4261 redValue =
4262 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
4263 llvm::Value *privateRedValue = builder.CreateLoad(
4264 reductionType, privateReductionVar, "red.private.value." + Twine(i));
4265 llvm::Value *reduced;
4266
4267 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4268 if (failed(handleError(res, opInst)))
4269 return failure();
4270 builder.restoreIP(res.get());
4271
4272 // For by-ref case, the store is inside of the reduction region.
4273 if (!byRef)
4274 builder.CreateStore(reduced, originalVariable);
4275 }
4276
4277 // After the construct, deallocate private reduction variables.
4278 SmallVector<Region *> reductionRegions;
4279 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4280 [](omp::DeclareReductionOp reductionDecl) {
4281 return &reductionDecl.getCleanupRegion();
4282 });
4283 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
4284 moduleTranslation, builder,
4285 "omp.reduction.cleanup")))
4286 return failure();
4287
4288 return cleanupPrivateVars(simdOp, builder, moduleTranslation, simdOp.getLoc(),
4289 privateVarsInfo);
4290}
4291
4292/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
4293static LogicalResult
4294convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
4295 LLVM::ModuleTranslation &moduleTranslation) {
4296 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4297 auto loopOp = cast<omp::LoopNestOp>(opInst);
4298
4299 if (failed(checkImplementationStatus(opInst)))
4300 return failure();
4301
4302 // Set up the source location value for OpenMP runtime.
4303 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4304
4305 // Generator of the canonical loop body.
4308 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4309 llvm::Value *iv) -> llvm::Error {
4310 // Make sure further conversions know about the induction variable.
4311 moduleTranslation.mapValue(
4312 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4313
4314 // Capture the body insertion point for use in nested loops. BodyIP of the
4315 // CanonicalLoopInfo always points to the beginning of the entry block of
4316 // the body.
4317 bodyInsertPoints.push_back(ip);
4318
4319 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4320 return llvm::Error::success();
4321
4322 // Convert the body of the loop.
4323 builder.restoreIP(ip);
4325 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
4326 if (!regionBlock)
4327 return regionBlock.takeError();
4328
4329 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4330 return llvm::Error::success();
4331 };
4332
4333 // Delegate actual loop construction to the OpenMP IRBuilder.
4334 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
4335 // loop, i.e. it has a positive step, uses signed integer semantics.
4336 // Reconsider this code when the nested loop operation clearly supports more
4337 // cases.
4338 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4339 llvm::Value *lowerBound =
4340 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
4341 llvm::Value *upperBound =
4342 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
4343 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
4344
4345 // Make sure loop trip count are emitted in the preheader of the outermost
4346 // loop at the latest so that they are all available for the new collapsed
4347 // loop will be created below.
4348 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4349 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4350 if (i != 0) {
4351 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4352 ompLoc.DL);
4353 computeIP = loopInfos.front()->getPreheaderIP();
4354 }
4355
4357 ompBuilder->createCanonicalLoop(
4358 loc, bodyGen, lowerBound, upperBound, step,
4359 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
4360
4361 if (failed(handleError(loopResult, *loopOp)))
4362 return failure();
4363
4364 loopInfos.push_back(*loopResult);
4365 }
4366
4367 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4368 loopInfos.front()->getAfterIP();
4369
4370 // Do tiling.
4371 if (const auto &tiles = loopOp.getTileSizes()) {
4372 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4374
4375 for (auto tile : tiles.value()) {
4376 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
4377 tileSizes.push_back(tileVal);
4378 }
4379
4380 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4381 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4382
4383 // Update afterIP to get the correct insertion point after
4384 // tiling.
4385 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4386 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4387 afterIP = {afterAfterBB, afterAfterBB->begin()};
4388
4389 // Update the loop infos.
4390 loopInfos.clear();
4391 for (const auto &newLoop : newLoops)
4392 loopInfos.push_back(newLoop);
4393 } // Tiling done.
4394
4395 // Do collapse.
4396 const auto &numCollapse = loopOp.getCollapseNumLoops();
4398 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4399
4400 auto newTopLoopInfo =
4401 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4402
4403 assert(newTopLoopInfo && "New top loop information is missing");
4404 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
4405 [&](OpenMPLoopInfoStackFrame &frame) {
4406 frame.loopInfo = newTopLoopInfo;
4407 return WalkResult::interrupt();
4408 });
4409
4410 // Continue building IR after the loop. Note that the LoopInfo returned by
4411 // `collapseLoops` points inside the outermost loop and is intended for
4412 // potential further loop transformations. Use the insertion point stored
4413 // before collapsing loops instead.
4414 builder.restoreIP(afterIP);
4415 return success();
4416}
4417
4418/// Convert an omp.canonical_loop to LLVM-IR
4419static LogicalResult
4420convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
4421 LLVM::ModuleTranslation &moduleTranslation) {
4422 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4423
4424 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4425 Value loopIV = op.getInductionVar();
4426 Value loopTC = op.getTripCount();
4427
4428 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
4429
4431 ompBuilder->createCanonicalLoop(
4432 loopLoc,
4433 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4434 // Register the mapping of MLIR induction variable to LLVM-IR
4435 // induction variable
4436 moduleTranslation.mapValue(loopIV, llvmIV);
4437
4438 builder.restoreIP(ip);
4440 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
4441 moduleTranslation);
4442
4443 return bodyGenStatus.takeError();
4444 },
4445 llvmTC, "omp.loop");
4446 if (!llvmOrError)
4447 return op.emitError(llvm::toString(llvmOrError.takeError()));
4448
4449 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4450 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4451 builder.restoreIP(afterIP);
4452
4453 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
4454 if (Value cli = op.getCli())
4455 moduleTranslation.mapOmpLoop(cli, llvmCLI);
4456
4457 return success();
4458}
4459
4460/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
4461/// OpenMPIRBuilder.
4462static LogicalResult
4463applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
4464 LLVM::ModuleTranslation &moduleTranslation) {
4465 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4466
4467 Value applyee = op.getApplyee();
4468 assert(applyee && "Loop to apply unrolling on required");
4469
4470 llvm::CanonicalLoopInfo *consBuilderCLI =
4471 moduleTranslation.lookupOMPLoop(applyee);
4472 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4473 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4474
4475 moduleTranslation.invalidateOmpLoop(applyee);
4476 return success();
4477}
4478
4479/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
4480/// OpenMPIRBuilder.
4481static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4482 LLVM::ModuleTranslation &moduleTranslation) {
4483 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4484 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4485
4487 SmallVector<llvm::Value *> translatedSizes;
4488
4489 for (Value size : op.getSizes()) {
4490 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
4491 assert(translatedSize &&
4492 "sizes clause arguments must already be translated");
4493 translatedSizes.push_back(translatedSize);
4494 }
4495
4496 for (Value applyee : op.getApplyees()) {
4497 llvm::CanonicalLoopInfo *consBuilderCLI =
4498 moduleTranslation.lookupOMPLoop(applyee);
4499 assert(applyee && "Canonical loop must already been translated");
4500 translatedLoops.push_back(consBuilderCLI);
4501 }
4502
4503 auto generatedLoops =
4504 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4505 if (!op.getGeneratees().empty()) {
4506 for (auto [mlirLoop, genLoop] :
4507 zip_equal(op.getGeneratees(), generatedLoops))
4508 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
4509 }
4510
4511 // CLIs can only be consumed once
4512 for (Value applyee : op.getApplyees())
4513 moduleTranslation.invalidateOmpLoop(applyee);
4514
4515 return success();
4516}
4517
4518/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
4519/// OpenMPIRBuilder.
4520static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4521 LLVM::ModuleTranslation &moduleTranslation) {
4522 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4523 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4524
4525 // Select what CLIs are going to be fused
4526 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
4527 for (size_t i = 0; i < op.getApplyees().size(); i++) {
4528 Value applyee = op.getApplyees()[i];
4529 llvm::CanonicalLoopInfo *consBuilderCLI =
4530 moduleTranslation.lookupOMPLoop(applyee);
4531 assert(applyee && "Canonical loop must already been translated");
4532 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4533 beforeFuse.push_back(consBuilderCLI);
4534 else if (op.getCount().has_value() &&
4535 i >= op.getFirst().value() + op.getCount().value() - 1)
4536 afterFuse.push_back(consBuilderCLI);
4537 else
4538 toFuse.push_back(consBuilderCLI);
4539 }
4540 assert(
4541 (op.getGeneratees().empty() ||
4542 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4543 "Wrong number of generatees");
4544
4545 // do the fuse
4546 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4547 if (!op.getGeneratees().empty()) {
4548 size_t i = 0;
4549 for (; i < beforeFuse.size(); i++)
4550 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4551 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4552 for (; i < afterFuse.size(); i++)
4553 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4554 }
4555
4556 // CLIs can only be consumed once
4557 for (Value applyee : op.getApplyees())
4558 moduleTranslation.invalidateOmpLoop(applyee);
4559
4560 return success();
4561}
4562
4563/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
4564static llvm::AtomicOrdering
4565convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
4566 if (!ao)
4567 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
4568
4569 switch (*ao) {
4570 case omp::ClauseMemoryOrderKind::Seq_cst:
4571 return llvm::AtomicOrdering::SequentiallyConsistent;
4572 case omp::ClauseMemoryOrderKind::Acq_rel:
4573 return llvm::AtomicOrdering::AcquireRelease;
4574 case omp::ClauseMemoryOrderKind::Acquire:
4575 return llvm::AtomicOrdering::Acquire;
4576 case omp::ClauseMemoryOrderKind::Release:
4577 return llvm::AtomicOrdering::Release;
4578 case omp::ClauseMemoryOrderKind::Relaxed:
4579 return llvm::AtomicOrdering::Monotonic;
4580 }
4581 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
4582}
4583
4584/// Convert omp.atomic.read operation to LLVM IR.
4585static LogicalResult
4586convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
4587 LLVM::ModuleTranslation &moduleTranslation) {
4588 auto readOp = cast<omp::AtomicReadOp>(opInst);
4589 if (failed(checkImplementationStatus(opInst)))
4590 return failure();
4591
4592 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4593 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4594 findAllocInsertPoints(builder, moduleTranslation);
4595
4596 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4597
4598 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
4599 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
4600 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
4601
4602 llvm::Type *elementType =
4603 moduleTranslation.convertType(readOp.getElementType());
4604
4605 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
4606 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
4607 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4608 return success();
4609}
4610
4611/// Converts an omp.atomic.write operation to LLVM IR.
4612static LogicalResult
4613convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
4614 LLVM::ModuleTranslation &moduleTranslation) {
4615 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4616 if (failed(checkImplementationStatus(opInst)))
4617 return failure();
4618
4619 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4620 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4621 findAllocInsertPoints(builder, moduleTranslation);
4622
4623 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4624 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
4625 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
4626 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
4627 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
4628 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
4629 /*isVolatile=*/false};
4630 builder.restoreIP(
4631 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4632 return success();
4633}
4634
4635/// Converts an LLVM dialect binary operation to the corresponding enum value
4636/// for `atomicrmw` supported binary operation.
4637static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
4639 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
4640 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
4641 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
4642 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
4643 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
4644 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
4645 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
4646 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
4647 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
4648 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4649}
4650
4651static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
4652 bool &isIgnoreDenormalMode,
4653 bool &isFineGrainedMemory,
4654 bool &isRemoteMemory) {
4655 isIgnoreDenormalMode = false;
4656 isFineGrainedMemory = false;
4657 isRemoteMemory = false;
4658 if (atomicUpdateOp &&
4659 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4660 mlir::omp::AtomicControlAttr atomicControlAttr =
4661 atomicUpdateOp.getAtomicControlAttr();
4662 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4663 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4664 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4665 }
4666}
4667
4668/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
4669static LogicalResult
4670convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
4671 llvm::IRBuilderBase &builder,
4672 LLVM::ModuleTranslation &moduleTranslation) {
4673 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4674 if (failed(checkImplementationStatus(*opInst)))
4675 return failure();
4676
4677 // Convert values and types.
4678 auto &innerOpList = opInst.getRegion().front().getOperations();
4679 bool isXBinopExpr{false};
4680 llvm::AtomicRMWInst::BinOp binop;
4681 mlir::Value mlirExpr;
4682 llvm::Value *llvmExpr = nullptr;
4683 llvm::Value *llvmX = nullptr;
4684 llvm::Type *llvmXElementType = nullptr;
4685 if (innerOpList.size() == 2) {
4686 // The two operations here are the update and the terminator.
4687 // Since we can identify the update operation, there is a possibility
4688 // that we can generate the atomicrmw instruction.
4689 mlir::Operation &innerOp = *opInst.getRegion().front().begin();
4690 if (!llvm::is_contained(innerOp.getOperands(),
4691 opInst.getRegion().getArgument(0))) {
4692 return opInst.emitError("no atomic update operation with region argument"
4693 " as operand found inside atomic.update region");
4694 }
4695 binop = convertBinOpToAtomic(innerOp);
4696 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
4697 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4698 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4699 } else {
4700 // Since the update region includes more than one operation
4701 // we will resort to generating a cmpxchg loop.
4702 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4703 }
4704 llvmX = moduleTranslation.lookupValue(opInst.getX());
4705 llvmXElementType = moduleTranslation.convertType(
4706 opInst.getRegion().getArgument(0).getType());
4707 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4708 /*isSigned=*/false,
4709 /*isVolatile=*/false};
4710
4711 llvm::AtomicOrdering atomicOrdering =
4712 convertAtomicOrdering(opInst.getMemoryOrder());
4713
4714 // Generate update code.
4715 auto updateFn =
4716 [&opInst, &moduleTranslation](
4717 llvm::Value *atomicx,
4718 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4719 Block &bb = *opInst.getRegion().begin();
4720 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
4721 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4722 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4723 return llvm::make_error<PreviouslyReportedError>();
4724
4725 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4726 assert(yieldop && yieldop.getResults().size() == 1 &&
4727 "terminator must be omp.yield op and it must have exactly one "
4728 "argument");
4729 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4730 };
4731
4732 bool isIgnoreDenormalMode;
4733 bool isFineGrainedMemory;
4734 bool isRemoteMemory;
4735 extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
4736 isRemoteMemory);
4737 // Handle ambiguous alloca, if any.
4738 auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
4739 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4740 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4741 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4742 atomicOrdering, binop, updateFn,
4743 isXBinopExpr, isIgnoreDenormalMode,
4744 isFineGrainedMemory, isRemoteMemory);
4745
4746 if (failed(handleError(afterIP, *opInst)))
4747 return failure();
4748
4749 builder.restoreIP(*afterIP);
4750 return success();
4751}
4752
/// Converts an `omp.atomic.capture` construct (an atomic read paired with an
/// atomic write or atomic update of the same location) to LLVM IR using the
/// OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // A read+write pair always captures the value before the write.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // The capture is "postfix" (v = x; x = updated) when the update op is the
    // second of the two ops in the capture region.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      // Exactly one op plus the terminator: that op must be the binary update.
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      // Whether `x` appears as the LHS (x binop expr) or RHS (expr binop x).
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // Multi-op update region: no single atomicrmw binop applies; BAD_BINOP
      // signals the builder to use `updateFn` instead (presumably via a
      // cmpxchg-style expansion — not visible from here).
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback invoked by the builder to materialize the updated value;
  // `atomicx` is the current value of x at the update point.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    // Inline the atomic.update region, binding its block argument to the
    // current value of x, then yield the region's result.
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4853
4854static llvm::omp::Directive convertCancellationConstructType(
4855 omp::ClauseCancellationConstructType directive) {
4856 switch (directive) {
4857 case omp::ClauseCancellationConstructType::Loop:
4858 return llvm::omp::Directive::OMPD_for;
4859 case omp::ClauseCancellationConstructType::Parallel:
4860 return llvm::omp::Directive::OMPD_parallel;
4861 case omp::ClauseCancellationConstructType::Sections:
4862 return llvm::omp::Directive::OMPD_sections;
4863 case omp::ClauseCancellationConstructType::Taskgroup:
4864 return llvm::omp::Directive::OMPD_taskgroup;
4865 }
4866 llvm_unreachable("Unhandled cancellation construct type");
4867}
4868
4869static LogicalResult
4870convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4871 LLVM::ModuleTranslation &moduleTranslation) {
4872 if (failed(checkImplementationStatus(*op.getOperation())))
4873 return failure();
4874
4875 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4876 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4877
4878 llvm::Value *ifCond = nullptr;
4879 if (Value ifVar = op.getIfExpr())
4880 ifCond = moduleTranslation.lookupValue(ifVar);
4881
4882 llvm::omp::Directive cancelledDirective =
4883 convertCancellationConstructType(op.getCancelDirective());
4884
4885 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4886 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4887
4888 if (failed(handleError(afterIP, *op.getOperation())))
4889 return failure();
4890
4891 builder.restoreIP(afterIP.get());
4892
4893 return success();
4894}
4895
4896static LogicalResult
4897convertOmpCancellationPoint(omp::CancellationPointOp op,
4898 llvm::IRBuilderBase &builder,
4899 LLVM::ModuleTranslation &moduleTranslation) {
4900 if (failed(checkImplementationStatus(*op.getOperation())))
4901 return failure();
4902
4903 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4904 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4905
4906 llvm::omp::Directive cancelledDirective =
4907 convertCancellationConstructType(op.getCancelDirective());
4908
4909 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4910 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4911
4912 if (failed(handleError(afterIP, *op.getOperation())))
4913 return failure();
4914
4915 builder.restoreIP(afterIP.get());
4916
4917 return success();
4918}
4919
4920/// Converts an OpenMP Threadprivate operation into LLVM IR using
4921/// OpenMPIRBuilder.
4922static LogicalResult
4923convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4924 LLVM::ModuleTranslation &moduleTranslation) {
4925 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4926 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4927 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4928
4929 if (failed(checkImplementationStatus(opInst)))
4930 return failure();
4931
4932 Value symAddr = threadprivateOp.getSymAddr();
4933 auto *symOp = symAddr.getDefiningOp();
4934
4935 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4936 symOp = asCast.getOperand().getDefiningOp();
4937
4938 if (!isa<LLVM::AddressOfOp>(symOp))
4939 return opInst.emitError("Addressing symbol not found");
4940 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4941
4942 LLVM::GlobalOp global =
4943 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4944 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4945 llvm::Type *type = globalValue->getValueType();
4946 llvm::TypeSize typeSize =
4947 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4948 type);
4949 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4950 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4951 ompLoc, globalValue, size, global.getSymName() + ".cache");
4952 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4953
4954 return success();
4955}
4956
4957static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4958convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4959 switch (deviceClause) {
4960 case mlir::omp::DeclareTargetDeviceType::host:
4961 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4962 break;
4963 case mlir::omp::DeclareTargetDeviceType::nohost:
4964 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4965 break;
4966 case mlir::omp::DeclareTargetDeviceType::any:
4967 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4968 break;
4969 }
4970 llvm_unreachable("unhandled device clause");
4971}
4972
4973static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4975 mlir::omp::DeclareTargetCaptureClause captureClause) {
4976 switch (captureClause) {
4977 case mlir::omp::DeclareTargetCaptureClause::to:
4978 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4979 case mlir::omp::DeclareTargetCaptureClause::link:
4980 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4981 case mlir::omp::DeclareTargetCaptureClause::enter:
4982 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4983 case mlir::omp::DeclareTargetCaptureClause::none:
4984 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4985 }
4986 llvm_unreachable("unhandled capture clause");
4987}
4988
4990 Operation *op = value.getDefiningOp();
4991 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4992 op = addrCast->getOperand(0).getDefiningOp();
4993 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4994 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4995 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4996 }
4997 return nullptr;
4998}
4999
  // NOTE(review): the declaration line of this helper (its name/signature)
  // is missing from this view of the file, so it is left untouched here.
  // The body walks producer ops of `value`, peeling llvm.addrspacecast and
  // looking through hlfir.declare/fir.declare (matched by op name, since
  // those dialects are not linked into this library) to reach the underlying
  // base value.
  while (Operation *op = value.getDefiningOp()) {
    if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
      value = addrCast.getOperand();
    // Traces through hlfir.declare, fir.declare to reach the base address and
    // use for type lookup.
    else if (op->getName().getIdentifier() &&
             (op->getName().getIdentifier().str() == "hlfir.declare" ||
              op->getName().getIdentifier().str() == "fir.declare")) {
      // Operand 0 of a declare-like op carries the base address —
      // presumably true for both hlfir.declare and fir.declare; confirm.
      if (op->getNumOperands() > 0)
        value = op->getOperand(0);
      else
        break;
    } else {
      // Not a wrapper we recognize: stop unwrapping.
      break;
    }
  }
  return value;
}
5019
5020static llvm::SmallString<64>
5021getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
5022 llvm::OpenMPIRBuilder &ompBuilder,
5023 llvm::vfs::FileSystem &vfs) {
5024 llvm::SmallString<64> suffix;
5025 llvm::raw_svector_ostream os(suffix);
5026 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
5027 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
5028 auto fileInfoCallBack = [&loc]() {
5029 return std::pair<std::string, uint64_t>(
5030 llvm::StringRef(loc.getFilename()), loc.getLine());
5031 };
5032
5033 os << llvm::format(
5034 "_%x",
5035 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5036 }
5037 os << "_decl_tgt_ref_ptr";
5038
5039 return suffix;
5040}
5041
5042static bool isDeclareTargetLink(Value value) {
5043 if (auto declareTargetGlobal =
5044 dyn_cast_if_present<omp::DeclareTargetInterface>(
5045 getGlobalOpFromValue(value)))
5046 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5047 omp::DeclareTargetCaptureClause::link)
5048 return true;
5049 return false;
5050}
5051
5052static bool isDeclareTargetTo(Value value) {
5053 if (auto declareTargetGlobal =
5054 dyn_cast_if_present<omp::DeclareTargetInterface>(
5055 getGlobalOpFromValue(value)))
5056 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5057 omp::DeclareTargetCaptureClause::to ||
5058 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5059 omp::DeclareTargetCaptureClause::enter)
5060 return true;
5061 return false;
5062}
5063
5064// Returns the reference pointer generated by the lowering of the declare
5065// target operation in cases where the link clause is used or the to clause is
5066// used in USM mode.
5067static llvm::Value *
5069 LLVM::ModuleTranslation &moduleTranslation) {
5070 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5071 if (auto gOp =
5072 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
5073 if (auto declareTargetGlobal =
5074 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
5075 // In this case, we must utilise the reference pointer generated by
5076 // the declare target operation, similar to Clang
5077 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
5078 omp::DeclareTargetCaptureClause::link) ||
5079 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5080 omp::DeclareTargetCaptureClause::to &&
5081 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5083 gOp, *ompBuilder, moduleTranslation.getFileSystem());
5084
5085 if (gOp.getSymName().contains(suffix))
5086 return moduleTranslation.getLLVMModule()->getNamedValue(
5087 gOp.getSymName());
5088
5089 return moduleTranslation.getLLVMModule()->getNamedValue(
5090 (gOp.getSymName().str() + suffix.str()).str());
5091 }
5092 }
5093 }
5094 return nullptr;
5095}
5096
5097namespace {
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // Per-entry custom mapper operations, parallel to the base class's arrays;
  // nullptr where a map entry has no custom mapper.
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
struct MapInfoData : MapInfosTy {
  // True when the mapped variable is a declare-target global.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True when this entry is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The map-like clause operation each entry originates from.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The translated LLVM value of the mapped variable, before any adjustment.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  // NOTE(review): IsAMember and IsAMapping are not appended here — confirm
  // this asymmetry is intentional for all callers of append().
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
5136
// Discriminates which target-related directive family an operation belongs
// to; produced by getTargetDirectiveEnumTyFromOp below.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};
5145
5146static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5147 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5148 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
5149 .Case([](omp::TargetEnterDataOp) {
5150 return TargetDirectiveEnumTy::TargetEnterData;
5151 })
5152 .Case([&](omp::TargetExitDataOp) {
5153 return TargetDirectiveEnumTy::TargetExitData;
5154 })
5155 .Case([&](omp::TargetUpdateOp) {
5156 return TargetDirectiveEnumTy::TargetUpdate;
5157 })
5158 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
5159 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
5160}
5161
5162} // namespace
5163
5164static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
5165 DataLayout &dl) {
5166 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5167 arrTy.getElementType()))
5168 return getArrayElementSizeInBits(nestedArrTy, dl);
5169 return dl.getTypeSizeInBits(arrTy.getElementType());
5170}
5171
5172// This function calculates the size to be offloaded for a specified type, given
5173// its associated map clause (which can contain bounds information which affects
5174// the total size), this size is calculated based on the underlying element type
5175// e.g. given a 1-D array of ints, we will calculate the size from the integer
5176// type * number of elements in the array. This size can be used in other
5177// calculations but is ultimately used as an argument to the OpenMP runtimes
5178// kernel argument structure which is generated through the combinedInfo data
5179// structures.
5180// This function is somewhat equivalent to Clang's getExprTypeSize inside of
5181// CGOpenMPRuntime.cpp.
5182static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
5183 Operation *clauseOp,
5184 llvm::Value *basePointer,
5185 llvm::Type *baseType,
5186 llvm::IRBuilderBase &builder,
5187 LLVM::ModuleTranslation &moduleTranslation) {
5188 if (auto memberClause =
5189 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5190 // This calculates the size to transfer based on bounds and the underlying
5191 // element type, provided bounds have been specified (Fortran
5192 // pointers/allocatables/target and arrays that have sections specified fall
5193 // into this as well)
5194 if (!memberClause.getBounds().empty()) {
5195 llvm::Value *elementCount = builder.getInt64(1);
5196 for (auto bounds : memberClause.getBounds()) {
5197 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5198 bounds.getDefiningOp())) {
5199 // The below calculation for the size to be mapped calculated from the
5200 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
5201 // multiply by the underlying element types byte size to get the full
5202 // size to be offloaded based on the bounds
5203 elementCount = builder.CreateMul(
5204 elementCount,
5205 builder.CreateAdd(
5206 builder.CreateSub(
5207 moduleTranslation.lookupValue(boundOp.getUpperBound()),
5208 moduleTranslation.lookupValue(boundOp.getLowerBound())),
5209 builder.getInt64(1)));
5210 }
5211 }
5212
5213 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
5214 // the size in inconsistent byte or bit format.
5215 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
5216 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5217 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
5218
5219 // The size in bytes x number of elements, the sizeInBytes stored is
5220 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
5221 // size, so we do some on the fly runtime math to get the size in
5222 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
5223 // some adjustment for members with more complex types.
5224 return builder.CreateMul(elementCount,
5225 builder.getInt64(underlyingTypeSzInBits / 8));
5226 }
5227 }
5228
5229 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
5230}
5231
5232// Convert the MLIR map flag set to the runtime map flag set for embedding
5233// in LLVM-IR. This is important as the two bit-flag lists do not correspond
5234// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
5235// Certain flags are discarded here such as RefPtee and co.
5236static llvm::omp::OpenMPOffloadMappingFlags
5237convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
5238 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
5239 return (mlirFlags & flag) == flag;
5240 };
5241 const bool hasExplicitMap =
5242 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
5243 omp::ClauseMapFlags::none;
5244
5245 llvm::omp::OpenMPOffloadMappingFlags mapType =
5246 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5247
5248 if (mapTypeToBool(omp::ClauseMapFlags::to))
5249 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5250
5251 if (mapTypeToBool(omp::ClauseMapFlags::from))
5252 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5253
5254 if (mapTypeToBool(omp::ClauseMapFlags::always))
5255 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5256
5257 if (mapTypeToBool(omp::ClauseMapFlags::del))
5258 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5259
5260 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
5261 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5262
5263 if (mapTypeToBool(omp::ClauseMapFlags::priv))
5264 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5265
5266 if (mapTypeToBool(omp::ClauseMapFlags::literal))
5267 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5268
5269 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
5270 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5271
5272 if (mapTypeToBool(omp::ClauseMapFlags::close))
5273 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5274
5275 if (mapTypeToBool(omp::ClauseMapFlags::present))
5276 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5277
5278 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
5279 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5280
5281 if (mapTypeToBool(omp::ClauseMapFlags::attach))
5282 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5283
5284 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
5285 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5286 if (!hasExplicitMap)
5287 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5288 }
5289
5290 return mapType;
5291}
5292
5294 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
5295 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
5296 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
5297 ArrayRef<Value> useDevAddrOperands = {},
5298 ArrayRef<Value> hasDevAddrOperands = {}) {
5299 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
5300 // Check if this is a member mapping and correctly assign that it is, if
5301 // it is a member of a larger object.
5302 // TODO: Need better handling of members, and distinguishing of members
5303 // that are implicitly allocated on device vs explicitly passed in as
5304 // arguments.
5305 // TODO: May require some further additions to support nested record
5306 // types, i.e. member maps that can have member maps.
5307 for (Value mapValue : mapVars) {
5308 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5309 for (auto member : map.getMembers())
5310 if (member == mapOp)
5311 return true;
5312 }
5313 return false;
5314 };
5315
5316 // Process MapOperands
5317 for (Value mapValue : mapVars) {
5318 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5319 Value offloadPtr =
5320 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5321 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
5322 mapData.Pointers.push_back(mapData.OriginalValue.back());
5323
5324 if (llvm::Value *refPtr =
5325 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
5326 mapData.IsDeclareTarget.push_back(true);
5327 mapData.BasePointers.push_back(refPtr);
5328 } else if (isDeclareTargetTo(offloadPtr)) {
5329 mapData.IsDeclareTarget.push_back(true);
5330 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5331 } else { // regular mapped variable
5332 mapData.IsDeclareTarget.push_back(false);
5333 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5334 }
5335
5336 mapData.BaseType.push_back(
5337 moduleTranslation.convertType(mapOp.getVarType()));
5338 mapData.Sizes.push_back(
5339 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
5340 mapData.BaseType.back(), builder, moduleTranslation));
5341 mapData.MapClause.push_back(mapOp.getOperation());
5342 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
5343 mapData.Names.push_back(LLVM::createMappingInformation(
5344 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5345 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5346 if (mapOp.getMapperId())
5347 mapData.Mappers.push_back(
5349 mapOp, mapOp.getMapperIdAttr()));
5350 else
5351 mapData.Mappers.push_back(nullptr);
5352 mapData.IsAMapping.push_back(true);
5353 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5354 }
5355
5356 auto findMapInfo = [&mapData](llvm::Value *val,
5357 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5358 unsigned index = 0;
5359 bool found = false;
5360 for (llvm::Value *basePtr : mapData.OriginalValue) {
5361 if (basePtr == val && mapData.IsAMapping[index]) {
5362 found = true;
5363 mapData.Types[index] |=
5364 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5365 mapData.DevicePointers[index] = devInfoTy;
5366 }
5367 index++;
5368 }
5369 return found;
5370 };
5371
5372 // Process useDevPtr(Addr)Operands
5373 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
5374 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5375 for (Value mapValue : useDevOperands) {
5376 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5377 Value offloadPtr =
5378 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5379 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5380
5381 // Check if map info is already present for this entry.
5382 if (!findMapInfo(origValue, devInfoTy)) {
5383 mapData.OriginalValue.push_back(origValue);
5384 mapData.Pointers.push_back(mapData.OriginalValue.back());
5385 mapData.IsDeclareTarget.push_back(false);
5386 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5387 mapData.BaseType.push_back(
5388 moduleTranslation.convertType(mapOp.getVarType()));
5389 mapData.Sizes.push_back(builder.getInt64(0));
5390 mapData.MapClause.push_back(mapOp.getOperation());
5391 mapData.Types.push_back(
5392 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5393 mapData.Names.push_back(LLVM::createMappingInformation(
5394 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5395 mapData.DevicePointers.push_back(devInfoTy);
5396 mapData.Mappers.push_back(nullptr);
5397 mapData.IsAMapping.push_back(false);
5398 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5399 }
5400 }
5401 };
5402
5403 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5404 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5405
5406 for (Value mapValue : hasDevAddrOperands) {
5407 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5408 Value offloadPtr =
5409 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5410 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5411 auto mapType = convertClauseMapFlags(mapOp.getMapType());
5412 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5413 bool isDevicePtr =
5414 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5415 omp::ClauseMapFlags::none;
5416
5417 mapData.OriginalValue.push_back(origValue);
5418 mapData.BasePointers.push_back(origValue);
5419 mapData.Pointers.push_back(origValue);
5420 mapData.IsDeclareTarget.push_back(false);
5421 mapData.BaseType.push_back(
5422 moduleTranslation.convertType(mapOp.getVarType()));
5423 mapData.Sizes.push_back(
5424 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
5425 mapData.MapClause.push_back(mapOp.getOperation());
5426 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5427 // Descriptors are mapped with the ALWAYS flag, since they can get
5428 // rematerialized, so the address of the decriptor for a given object
5429 // may change from one place to another.
5430 mapData.Types.push_back(mapType);
5431 // Technically it's possible for a non-descriptor mapping to have
5432 // both has-device-addr and ALWAYS, so lookup the mapper in case it
5433 // exists.
5434 if (mapOp.getMapperId()) {
5435 mapData.Mappers.push_back(
5437 mapOp, mapOp.getMapperIdAttr()));
5438 } else {
5439 mapData.Mappers.push_back(nullptr);
5440 }
5441 } else {
5442 // For is_device_ptr we need the map type to propagate so the runtime
5443 // can materialize the device-side copy of the pointer container.
5444 mapData.Types.push_back(
5445 isDevicePtr ? mapType
5446 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5447 mapData.Mappers.push_back(nullptr);
5448 }
5449 mapData.Names.push_back(LLVM::createMappingInformation(
5450 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5451 mapData.DevicePointers.push_back(
5452 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5453 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5454 mapData.IsAMapping.push_back(false);
5455 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5456 }
5457}
5458
5459static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
5460 auto *res = llvm::find(mapData.MapClause, memberOp);
5461 assert(res != mapData.MapClause.end() &&
5462 "MapInfoOp for member not found in MapData, cannot return index");
5463 return std::distance(mapData.MapClause.begin(), res);
5464}
5465
5467 omp::MapInfoOp mapInfo) {
5468 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5469 llvm::SmallVector<size_t> occludedChildren;
5470 llvm::sort(
5471 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
5472 // Bail early if we are asked to look at the same index. If we do not
5473 // bail early, we can end up mistakenly adding indices to
5474 // occludedChildren. This can occur with some types of libc++ hardening.
5475 if (a == b)
5476 return false;
5477
5478 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5479 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5480
5481 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5482 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5483 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5484
5485 if (aIndex == bIndex)
5486 continue;
5487
5488 if (aIndex < bIndex)
5489 return true;
5490
5491 if (aIndex > bIndex)
5492 return false;
5493 }
5494
5495 // Iterated up until the end of the smallest member and
5496 // they were found to be equal up to that point, so select
5497 // the member with the lowest index count, so the "parent"
5498 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5499 if (memberAParent)
5500 occludedChildren.push_back(b);
5501 else
5502 occludedChildren.push_back(a);
5503 return memberAParent;
5504 });
5505
5506 // We remove children from the index list that are overshadowed by
5507 // a parent, this prevents us retrieving these as the first or last
5508 // element when the parent is the correct element in these cases.
5509 for (auto v : occludedChildren)
5510 indices.erase(std::remove(indices.begin(), indices.end(), v),
5511 indices.end());
5512}
5513
5514static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
5515 bool first) {
5516 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5517 // Only 1 member has been mapped, we can return it.
5518 if (indexAttr.size() == 1)
5519 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5520 llvm::SmallVector<size_t> indices(indexAttr.size());
5521 std::iota(indices.begin(), indices.end(), 0);
5522 sortMapIndices(indices, mapInfo);
5523 return llvm::cast<omp::MapInfoOp>(
5524 mapInfo.getMembers()[first ? indices.front() : indices.back()]
5525 .getDefiningOp());
5526}
5527
5528/// This function calculates the array/pointer offset for map data provided
5529/// with bounds operations, e.g. when provided something like the following:
5530///
5531/// Fortran
5532/// map(tofrom: array(2:5, 3:2))
5533///
5534/// We must calculate the initial pointer offset to pass across, this function
5535/// performs this using bounds.
5536///
5537/// TODO/WARNING: This only supports Fortran's column major indexing currently
5538/// as is noted in the note below and comments in the function, we must extend
5539/// this function when we add a C++ frontend.
5540/// NOTE: which while specified in row-major order it currently needs to be
5541/// flipped for Fortran's column order array allocation and access (as
5542/// opposed to C++'s row-major, hence the backwards processing where order is
5543/// important). This is likely important to keep in mind for the future when
5544/// we incorporate a C++ frontend, both frontends will need to agree on the
5545/// ordering of generated bounds operations (one may have to flip them) to
5546/// make the below lowering frontend agnostic. The offload size
5547/// calcualtion may also have to be adjusted for C++.
5548static std::vector<llvm::Value *>
5550 llvm::IRBuilderBase &builder, bool isArrayTy,
5551 OperandRange bounds) {
5552 std::vector<llvm::Value *> idx;
5553 // There's no bounds to calculate an offset from, we can safely
5554 // ignore and return no indices.
5555 if (bounds.empty())
5556 return idx;
5557
5558 // If we have an array type, then we have its type so can treat it as a
5559 // normal GEP instruction where the bounds operations are simply indexes
5560 // into the array. We currently do reverse order of the bounds, which
5561 // I believe leans more towards Fortran's column-major in memory.
5562 if (isArrayTy) {
5563 idx.push_back(builder.getInt64(0));
5564 for (int i = bounds.size() - 1; i >= 0; --i) {
5565 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5566 bounds[i].getDefiningOp())) {
5567 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
5568 }
5569 }
5570 } else {
5571 // If we do not have an array type, but we have bounds, then we're dealing
5572 // with a pointer that's being treated like an array and we have the
5573 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
5574 // address (pointer pointing to the actual data) so we must caclulate the
5575 // offset using a single index which the following loop attempts to
5576 // compute using the standard column-major algorithm e.g for a 3D array:
5577 //
5578 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
5579 //
5580 // It is of note that it's doing column-major rather than row-major at the
5581 // moment, but having a way for the frontend to indicate which major format
5582 // to use or standardizing/canonicalizing the order of the bounds to compute
5583 // the offset may be useful in the future when there's other frontends with
5584 // different formats.
5585 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5586 for (int i = bounds.size() - 1; i >= 0; --i) {
5587 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5588 bounds[i].getDefiningOp())) {
5589 if (i == ((int)bounds.size() - 1))
5590 idx.emplace_back(
5591 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5592 else
5593 idx.back() = builder.CreateAdd(
5594 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
5595 boundOp.getExtent())),
5596 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5597 }
5598 }
5599 }
5600
5601 return idx;
5602}
5603
5605 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
5606 return cast<IntegerAttr>(value).getInt();
5607 });
5608}
5609
5610// Gathers members that are overlapping in the parent, excluding members that
5611// themselves overlap, keeping the top-most (closest to parents level) map.
5612static void
5614 omp::MapInfoOp parentOp) {
5615 // No members mapped, no overlaps.
5616 if (parentOp.getMembers().empty())
5617 return;
5618
5619 // Single member, we can insert and return early.
5620 if (parentOp.getMembers().size() == 1) {
5621 overlapMapDataIdxs.push_back(0);
5622 return;
5623 }
5624
5625 // 1) collect list of top-level overlapping members from MemberOp
5627 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5628 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5629 memberByIndex.push_back(
5630 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
5631
5632 // Sort the smallest first (higher up the parent -> member chain), so that
5633 // when we remove members, we remove as much as we can in the initial
5634 // iterations, shortening the number of passes required.
5635 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5636 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
5637
5638 // Remove elements from the vector if there is a parent element that
5639 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
5640 // [0,2].. etc.
5642 for (auto v : memberByIndex) {
5643 llvm::SmallVector<int64_t> vArr(v.second.size());
5644 getAsIntegers(v.second, vArr);
5645 skipList.push_back(
5646 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
5647 if (v == x)
5648 return false;
5649 llvm::SmallVector<int64_t> xArr(x.second.size());
5650 getAsIntegers(x.second, xArr);
5651 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5652 xArr.size() >= vArr.size();
5653 }));
5654 }
5655
5656 // Collect the indices, as we need the base pointer etc. from the MapData
5657 // structure which is primarily accessible via index at the moment.
5658 for (auto v : memberByIndex)
5659 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5660 overlapMapDataIdxs.push_back(v.first);
5661}
5662
5663// The intent is to verify if the mapped data being passed is a
5664// pointer -> pointee that requires special handling in certain cases,
5665// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
5666//
5667// There may be a better way to verify this, but unfortunately with
5668// opaque pointers we lose the ability to easily check if something is
5669// a pointer whilst maintaining access to the underlying type.
5670static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
5671 // If we have a varPtrPtr field assigned then the underlying type is a pointer
5672 if (mapOp.getVarPtrPtr())
5673 return true;
5674
5675 // If the map data is declare target with a link clause, then it's represented
5676 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
5677 // no relation to pointers.
5678 if (isDeclareTargetLink(mapOp.getVarPtr()))
5679 return true;
5680
5681 return false;
5682}
5683
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause. Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map type
// with it) to indicate that a member is part of this parent and should be
// treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  // User-defined mapper attached to the parent clause (may be null).
  auto *parentMapper = mapData.Mappers[mapDataIndex];

  // Map the first segment of the parent. If a user-defined mapper is attached,
  // include the parent's to/from-style bits (and common modifiers) in this
  // base entry so the mapper receives correct copy semantics via its 'type'
  // parameter. Also keep TARGET_PARAM when required for kernel arguments.
  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  if (parentMapper) {
    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
    // Preserve relevant map-type bits from the parent clause. These include
    // the copy direction (TO/FROM), as well as commonly used modifiers that
    // should be visible to the mapper for correct behaviour.
    mapFlags parentFlags = mapData.Types[mapDataIndex];
    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
    baseFlag |= (parentFlags & preserve);
  }

  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  // Only attach the mapper to the base entry when we are mapping the whole
  // parent. Combined/segment entries must not carry a mapper; otherwise the
  // mapper can be invoked with a partial size, which is undefined behaviour.
  combinedInfo.Mappers.emplace_back(
      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Whole-object map: size spans from the parent pointer to one element
    // past it.
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: size spans from the lowest-addressed mapped member to one
    // element past the highest-addressed mapped member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  // The MEMBER_OF flag encodes the position of the base entry we just
  // emitted; member entries reference their parent through it.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // Single whole-object entry; no overlap segmentation for target update
      // or close-modified maps.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
    } else {
      llvm::SmallVector<size_t> overlapIdxs;
      // Find all of the members that "overlap", i.e. occlude other members that
      // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
      // [0,2], but not [1,0].
      getOverlappedMembers(overlapIdxs, parentClause);
      // We need to make sure the overlapped members are sorted in order of
      // lowest address to highest address.
      sortMapIndices(overlapIdxs, parentClause);

      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      // TODO: We may want to skip arrays/array sections in this as Clang does.
      // It appears to be an optimisation rather than a necessity though,
      // but this requires further investigation. However, we would have to make
      // sure to not exclude maps with bounds that ARE pointers, as these are
      // processed as separate components, i.e. pointer + data.
      //
      // Emit one "gap" segment per overlapped member: each segment covers the
      // bytes between the previous segment's end (lowAddr) and the start of
      // the overlapped member, then lowAddr is advanced past that member.
      for (auto v : overlapIdxs) {
        auto mapDataOverlapIdx = getMapDataMemberIdx(
            mapData,
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(nullptr);
        combinedInfo.Pointers.emplace_back(lowAddr);
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
                                  lowAddr),
            builder.getInt64Ty(), /*isSigned=*/true));
        lowAddr = builder.CreateConstGEP1_32(
            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      // Final segment: from the end of the last overlapped member to the end
      // of the parent object.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), true));
    }
  }
  return memberOfFlag;
}
5878
5879// This function is intended to add explicit mappings of members
5881 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5882 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5883 MapInfoData &mapData, uint64_t mapDataIndex,
5884 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5885 TargetDirectiveEnumTy targetDirective) {
5886 assert(!ompBuilder.Config.isTargetDevice() &&
5887 "function only supported for host device codegen");
5888
5889 auto parentClause =
5890 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5891
5892 for (auto mappedMembers : parentClause.getMembers()) {
5893 auto memberClause =
5894 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5895 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5896
5897 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5898
5899 // If we're currently mapping a pointer to a block of data, we must
5900 // initially map the pointer, and then attatch/bind the data with a
5901 // subsequent map to the pointer. This segment of code generates the
5902 // pointer mapping, which can in certain cases be optimised out as Clang
5903 // currently does in its lowering. However, for the moment we do not do so,
5904 // in part as we currently have substantially less information on the data
5905 // being mapped at this stage.
5906 if (checkIfPointerMap(memberClause)) {
5907 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5908 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5909 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5910 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5911 combinedInfo.Types.emplace_back(mapFlag);
5912 combinedInfo.DevicePointers.emplace_back(
5913 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5914 combinedInfo.Mappers.emplace_back(nullptr);
5915 combinedInfo.Names.emplace_back(
5916 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5917 combinedInfo.BasePointers.emplace_back(
5918 mapData.BasePointers[mapDataIndex]);
5919 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5920 combinedInfo.Sizes.emplace_back(builder.getInt64(
5921 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5922 }
5923
5924 // Same MemberOfFlag to indicate its link with parent and other members
5925 // of.
5926 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5927 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5928 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5929 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5930 bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5931 ? parentClause.getVarPtr()
5932 : parentClause.getVarPtrPtr());
5933 if (checkIfPointerMap(memberClause) &&
5934 (!isDeclTargetTo ||
5935 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5936 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5937 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5938 }
5939
5940 combinedInfo.Types.emplace_back(mapFlag);
5941 combinedInfo.DevicePointers.emplace_back(
5942 mapData.DevicePointers[memberDataIdx]);
5943 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5944 combinedInfo.Names.emplace_back(
5945 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5946 uint64_t basePointerIndex =
5947 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5948 combinedInfo.BasePointers.emplace_back(
5949 mapData.BasePointers[basePointerIndex]);
5950 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5951
5952 llvm::Value *size = mapData.Sizes[memberDataIdx];
5953 if (checkIfPointerMap(memberClause)) {
5954 size = builder.CreateSelect(
5955 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5956 builder.getInt64(0), size);
5957 }
5958
5959 combinedInfo.Sizes.emplace_back(size);
5960 }
5961}
5962
5963static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5964 MapInfosTy &combinedInfo,
5965 TargetDirectiveEnumTy targetDirective,
5966 int mapDataParentIdx = -1) {
5967 // Declare Target Mappings are excluded from being marked as
5968 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5969 // marked with OMP_MAP_PTR_AND_OBJ instead.
5970 auto mapFlag = mapData.Types[mapDataIdx];
5971 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5972
5973 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5974 if (isPtrTy)
5975 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5976
5977 if (targetDirective == TargetDirectiveEnumTy::Target &&
5978 !mapData.IsDeclareTarget[mapDataIdx])
5979 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5980
5981 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5982 !isPtrTy)
5983 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5984
5985 // if we're provided a mapDataParentIdx, then the data being mapped is
5986 // part of a larger object (in a parent <-> member mapping) and in this
5987 // case our BasePointer should be the parent.
5988 if (mapDataParentIdx >= 0)
5989 combinedInfo.BasePointers.emplace_back(
5990 mapData.BasePointers[mapDataParentIdx]);
5991 else
5992 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5993
5994 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5995 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5996 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5997 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5998 combinedInfo.Types.emplace_back(mapFlag);
5999 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
6000}
6001
6003 llvm::IRBuilderBase &builder,
6004 llvm::OpenMPIRBuilder &ompBuilder,
6005 DataLayout &dl, MapInfosTy &combinedInfo,
6006 MapInfoData &mapData, uint64_t mapDataIndex,
6007 TargetDirectiveEnumTy targetDirective) {
6008 assert(!ompBuilder.Config.isTargetDevice() &&
6009 "function only supported for host device codegen");
6010
6011 auto parentClause =
6012 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6013
6014 // If we have a partial map (no parent referenced in the map clauses of the
6015 // directive, only members) and only a single member, we do not need to bind
6016 // the map of the member to the parent, we can pass the member separately.
6017 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6018 auto memberClause = llvm::cast<omp::MapInfoOp>(
6019 parentClause.getMembers()[0].getDefiningOp());
6020 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
6021 // Note: Clang treats arrays with explicit bounds that fall into this
6022 // category as a parent with map case, however, it seems this isn't a
6023 // requirement, and processing them as an individual map is fine. So,
6024 // we will handle them as individual maps for the moment, as it's
6025 // difficult for us to check this as we always require bounds to be
6026 // specified currently and it's also marginally more optimal (single
6027 // map rather than two). The difference may come from the fact that
6028 // Clang maps array without bounds as pointers (which we do not
6029 // currently do), whereas we treat them as arrays in all cases
6030 // currently.
6031 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
6032 mapDataIndex);
6033 return;
6034 }
6035
6036 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
6037 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
6038 combinedInfo, mapData, mapDataIndex,
6039 targetDirective);
6040 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
6041 combinedInfo, mapData, mapDataIndex,
6042 memberOfParentFlag, targetDirective);
6043}
6044
// This is a variation on Clang's GenerateOpenMPCapturedVars, which
// generates different operation (e.g. load/store) combinations for
// arguments to the kernel, based on map capture kinds which are then
// utilised in the combinedInfo in place of the original Map value.
//
// Mutates mapData.Pointers (and for by-copy captures, mapData.BasePointers)
// in place; ModuleTranslation's value mapping is deliberately untouched.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defines, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        // Pointer-backed data (e.g. descriptor base address): load through
        // the pointer before applying any bounds offset.
        if (isPtrTy)
          newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        // Non-pointer by-copy values are spilled to a temporary alloca (at
        // the function's alloca insertion point) and reloaded, mirroring
        // Clang's casting of by-copy kernel arguments.
        if (!isPtrTy) {
          auto curInsert = builder.saveIP();
          llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
          builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
          builder.SetCurrentDebugLocation(DbgLoc);
          builder.restoreIP(curInsert);

          builder.CreateStore(newV, memTempAlloc);
          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}
6115
6116// Generate all map related information and fill the combinedInfo.
6117static void genMapInfos(llvm::IRBuilderBase &builder,
6118 LLVM::ModuleTranslation &moduleTranslation,
6119 DataLayout &dl, MapInfosTy &combinedInfo,
6120 MapInfoData &mapData,
6121 TargetDirectiveEnumTy targetDirective) {
6122 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6123 "function only supported for host device codegen");
6124 // We wish to modify some of the methods in which arguments are
6125 // passed based on their capture type by the target region, this can
6126 // involve generating new loads and stores, which changes the
6127 // MLIR value to LLVM value mapping, however, we only wish to do this
6128 // locally for the current function/target and also avoid altering
6129 // ModuleTranslation, so we remap the base pointer or pointer stored
6130 // in the map infos corresponding MapInfoData, which is later accessed
6131 // by genMapInfos and createTarget to help generate the kernel and
6132 // kernel arg structure. It primarily becomes relevant in cases like
6133 // bycopy, or byref range'd arrays. In the default case, we simply
6134 // pass thee pointer byref as both basePointer and pointer.
6135 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
6136
6137 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6138
6139 // We operate under the assumption that all vectors that are
6140 // required in MapInfoData are of equal lengths (either filled with
6141 // default constructed data or appropiate information) so we can
6142 // utilise the size from any component of MapInfoData, if we can't
6143 // something is missing from the initial MapInfoData construction.
6144 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6145 // NOTE/TODO: We currently do not support arbitrary depth record
6146 // type mapping.
6147 if (mapData.IsAMember[i])
6148 continue;
6149
6150 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
6151 if (!mapInfoOp.getMembers().empty()) {
6152 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
6153 combinedInfo, mapData, i, targetDirective);
6154 continue;
6155 }
6156
6157 processIndividualMap(mapData, i, combinedInfo, targetDirective);
6158 }
6159}
6160
// Forward declaration: emitUserDefinedMapper and
// getOrCreateUserDefinedMapperFunc are mutually recursive — a mapper's map
// infos may reference further user-defined mappers via its customMapperCB.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective);
6166
6167static llvm::Expected<llvm::Function *>
6168getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
6169 LLVM::ModuleTranslation &moduleTranslation,
6170 TargetDirectiveEnumTy targetDirective) {
6171 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6172 "function only supported for host device codegen");
6173 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6174 std::string mapperFuncName =
6175 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
6176 {"omp_mapper", declMapperOp.getSymName()});
6177
6178 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
6179 return lookupFunc;
6180
6181 // Recursive types can cause re-entrant mapper emission. The mapper function
6182 // is created by OpenMPIRBuilder before the callbacks run, so it may already
6183 // exist in the LLVM module even though it is not yet registered in the
6184 // ModuleTranslation mapping table. Reuse and register it to break the
6185 // recursion.
6186 if (llvm::Function *existingFunc =
6187 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
6188 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
6189 return existingFunc;
6190 }
6191
6192 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
6193 mapperFuncName, targetDirective);
6194}
6195
// Emits the LLVM function for the user-defined mapper declared by \p op
// (an omp::DeclareMapperOp), named \p mapperFuncName, and registers it in the
// ModuleTranslation function table. Returns the function or a previously
// reported error from converting the mapper's region.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's symbolic variable to the pointer supplied by the
    // runtime, then translate the mapper's region body.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Resolves nested user-defined mappers; mutually recursive with
  // getOrCreateUserDefinedMapperFunc.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Re-entrant emission may already have registered the function (see
  // getOrCreateUserDefinedMapperFunc); otherwise register it now.
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
6257
6258static LogicalResult
6259convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
6260 LLVM::ModuleTranslation &moduleTranslation) {
6261 llvm::Value *ifCond = nullptr;
6262 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6263 SmallVector<Value> mapVars;
6264 SmallVector<Value> useDevicePtrVars;
6265 SmallVector<Value> useDeviceAddrVars;
6266 llvm::omp::RuntimeFunction RTLFn;
6267 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
6268 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6269
6270 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6271 llvm::OpenMPIRBuilder::TargetDataInfo info(
6272 /*RequiresDevicePointerInfo=*/true,
6273 /*SeparateBeginEndCalls=*/true);
6274 assert(!ompBuilder->Config.isTargetDevice() &&
6275 "target data/enter/exit/update are host ops");
6276 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6277
6278 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
6279 llvm::Value *v = moduleTranslation.lookupValue(dev);
6280 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
6281 };
6282
6283 LogicalResult result =
6285 .Case([&](omp::TargetDataOp dataOp) {
6286 if (failed(checkImplementationStatus(*dataOp)))
6287 return failure();
6288
6289 if (auto ifVar = dataOp.getIfExpr())
6290 ifCond = moduleTranslation.lookupValue(ifVar);
6291
6292 if (mlir::Value devId = dataOp.getDevice())
6293 deviceID = getDeviceID(devId);
6294
6295 mapVars = dataOp.getMapVars();
6296 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6297 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6298 return success();
6299 })
6300 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6301 if (failed(checkImplementationStatus(*enterDataOp)))
6302 return failure();
6303
6304 if (auto ifVar = enterDataOp.getIfExpr())
6305 ifCond = moduleTranslation.lookupValue(ifVar);
6306
6307 if (mlir::Value devId = enterDataOp.getDevice())
6308 deviceID = getDeviceID(devId);
6309
6310 RTLFn =
6311 enterDataOp.getNowait()
6312 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6313 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6314 mapVars = enterDataOp.getMapVars();
6315 info.HasNoWait = enterDataOp.getNowait();
6316 return success();
6317 })
6318 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6319 if (failed(checkImplementationStatus(*exitDataOp)))
6320 return failure();
6321
6322 if (auto ifVar = exitDataOp.getIfExpr())
6323 ifCond = moduleTranslation.lookupValue(ifVar);
6324
6325 if (mlir::Value devId = exitDataOp.getDevice())
6326 deviceID = getDeviceID(devId);
6327
6328 RTLFn = exitDataOp.getNowait()
6329 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6330 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6331 mapVars = exitDataOp.getMapVars();
6332 info.HasNoWait = exitDataOp.getNowait();
6333 return success();
6334 })
6335 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6336 if (failed(checkImplementationStatus(*updateDataOp)))
6337 return failure();
6338
6339 if (auto ifVar = updateDataOp.getIfExpr())
6340 ifCond = moduleTranslation.lookupValue(ifVar);
6341
6342 if (mlir::Value devId = updateDataOp.getDevice())
6343 deviceID = getDeviceID(devId);
6344
6345 RTLFn =
6346 updateDataOp.getNowait()
6347 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6348 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6349 mapVars = updateDataOp.getMapVars();
6350 info.HasNoWait = updateDataOp.getNowait();
6351 return success();
6352 })
6353 .DefaultUnreachable("unexpected operation");
6354
6355 if (failed(result))
6356 return failure();
6357 // Pretend we have IF(false) if we're not doing offload.
6358 if (!isOffloadEntry)
6359 ifCond = builder.getFalse();
6360
6361 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6362 MapInfoData mapData;
6363 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
6364 builder, useDevicePtrVars, useDeviceAddrVars);
6365
6366 // Fill up the arrays with all the mapped variables.
6367 MapInfosTy combinedInfo;
6368 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6369 builder.restoreIP(codeGenIP);
6370 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6371 targetDirective);
6372 return combinedInfo;
6373 };
6374
6375 // Define a lambda to apply mappings between use_device_addr and
6376 // use_device_ptr base pointers, and their associated block arguments.
6377 auto mapUseDevice =
6378 [&moduleTranslation](
6379 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6381 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
6382 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
6383 for (auto [arg, useDevVar] :
6384 llvm::zip_equal(blockArgs, useDeviceVars)) {
6385
6386 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6387 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6388 : mapInfoOp.getVarPtr();
6389 };
6390
6391 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6392 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6393 mapInfoData.MapClause, mapInfoData.DevicePointers,
6394 mapInfoData.BasePointers)) {
6395 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6396 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6397 devicePointer != type)
6398 continue;
6399
6400 if (llvm::Value *devPtrInfoMap =
6401 mapper ? mapper(basePointer) : basePointer) {
6402 moduleTranslation.mapValue(arg, devPtrInfoMap);
6403 break;
6404 }
6405 }
6406 }
6407 };
6408
6409 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6410 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6411 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6412 // We must always restoreIP regardless of doing anything the caller
6413 // does not restore it, leading to incorrect (no) branch generation.
6414 builder.restoreIP(codeGenIP);
6415 assert(isa<omp::TargetDataOp>(op) &&
6416 "BodyGen requested for non TargetDataOp");
6417 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6418 Region &region = cast<omp::TargetDataOp>(op).getRegion();
6419 switch (bodyGenType) {
6420 case BodyGenTy::Priv:
6421 // Check if any device ptr/addr info is available
6422 if (!info.DevicePtrInfoMap.empty()) {
6423 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6424 blockArgIface.getUseDeviceAddrBlockArgs(),
6425 useDeviceAddrVars, mapData,
6426 [&](llvm::Value *basePointer) -> llvm::Value * {
6427 if (!info.DevicePtrInfoMap[basePointer].second)
6428 return nullptr;
6429 return builder.CreateLoad(
6430 builder.getPtrTy(),
6431 info.DevicePtrInfoMap[basePointer].second);
6432 });
6433 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6434 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6435 mapData, [&](llvm::Value *basePointer) {
6436 return info.DevicePtrInfoMap[basePointer].second;
6437 });
6438
6439 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6440 moduleTranslation)))
6441 return llvm::make_error<PreviouslyReportedError>();
6442 }
6443 break;
6444 case BodyGenTy::DupNoPriv:
6445 if (info.DevicePtrInfoMap.empty()) {
6446 // For host device we still need to do the mapping for codegen,
6447 // otherwise it may try to lookup a missing value.
6448 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6449 blockArgIface.getUseDeviceAddrBlockArgs(),
6450 useDeviceAddrVars, mapData);
6451 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6452 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6453 mapData);
6454 }
6455 break;
6456 case BodyGenTy::NoPriv:
6457 // If device info is available then region has already been generated
6458 if (info.DevicePtrInfoMap.empty()) {
6459 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6460 moduleTranslation)))
6461 return llvm::make_error<PreviouslyReportedError>();
6462 }
6463 break;
6464 }
6465 return builder.saveIP();
6466 };
6467
6468 auto customMapperCB =
6469 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6470 if (!combinedInfo.Mappers[i])
6471 return nullptr;
6472 info.HasMapper = true;
6473 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6474 moduleTranslation, targetDirective);
6475 };
6476
6477 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6479 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6480 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
6481 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6482 if (isa<omp::TargetDataOp>(op))
6483 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6484 deallocBlocks, deviceID, ifCond, info,
6485 genMapInfoCB, customMapperCB,
6486 /*MapperFunc=*/nullptr, bodyGenCB,
6487 /*DeviceAddrCB=*/nullptr);
6488 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6489 deallocBlocks, deviceID, ifCond, info,
6490 genMapInfoCB, customMapperCB, &RTLFn);
6491 }();
6492
6493 if (failed(handleError(afterIP, *op)))
6494 return failure();
6495
6496 builder.restoreIP(*afterIP);
6497 return success();
6498}
6499
6500static LogicalResult
6501convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
6502 LLVM::ModuleTranslation &moduleTranslation) {
6503 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6504 auto distributeOp = cast<omp::DistributeOp>(opInst);
6505 if (failed(checkImplementationStatus(opInst)))
6506 return failure();
6507
6508 /// Process teams op reduction in distribute if the reduction is contained in
6509 /// this specific distribute op.
6510 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
6511 bool doDistributeReduction =
6512 teamsOp && getDistributeCapturingTeamsReduction(teamsOp) == distributeOp;
6513
6514 DenseMap<Value, llvm::Value *> reductionVariableMap;
6515 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
6517 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
6518 llvm::ArrayRef<bool> isByRef;
6519
6520 if (doDistributeReduction) {
6521 isByRef = getIsByRef(teamsOp.getReductionByref());
6522 assert(isByRef.size() == teamsOp.getNumReductionVars());
6523
6524 collectReductionDecls(teamsOp, reductionDecls);
6525 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6526 findAllocInsertPoints(builder, moduleTranslation);
6527
6528 MutableArrayRef<BlockArgument> reductionArgs =
6529 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6530 .getReductionBlockArgs();
6531
6533 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6534 reductionDecls, privateReductionVariables, reductionVariableMap,
6535 isByRef)))
6536 return failure();
6537 }
6538
6539 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6540 auto bodyGenCB =
6541 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
6542 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
6543 // Save the alloca insertion point on ModuleTranslation stack for use in
6544 // nested regions.
6546 moduleTranslation, allocaIP, deallocBlocks);
6547
6548 // DistributeOp has only one region associated with it.
6549 builder.restoreIP(codeGenIP);
6550 PrivateVarsInfo privVarsInfo(distributeOp);
6551
6553 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
6554 if (handleError(afterAllocas, opInst).failed())
6555 return llvm::make_error<PreviouslyReportedError>();
6556
6557 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
6558 opInst)
6559 .failed())
6560 return llvm::make_error<PreviouslyReportedError>();
6561
6562 if (failed(copyFirstPrivateVars(
6563 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
6564 privVarsInfo.llvmVars, privVarsInfo.privatizers,
6565 distributeOp.getPrivateNeedsBarrier())))
6566 return llvm::make_error<PreviouslyReportedError>();
6567
6568 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6569 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6571 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
6572 builder, moduleTranslation);
6573 if (!regionBlock)
6574 return regionBlock.takeError();
6575 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
6576
6577 // Skip applying a workshare loop below when translating 'distribute
6578 // parallel do' (it's been already handled by this point while translating
6579 // the nested omp.wsloop).
6580 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
6581 // TODO: Add support for clauses which are valid for DISTRIBUTE
6582 // constructs. Static schedule is the default.
6583 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6584 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6585 : omp::ClauseScheduleKind::Static;
6586 // dist_schedule clauses are ordered - otherise this should be false
6587 bool isOrdered = hasDistSchedule;
6588 std::optional<omp::ScheduleModifier> scheduleMod;
6589 bool isSimd = false;
6590 llvm::omp::WorksharingLoopType workshareLoopType =
6591 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6592 bool loopNeedsBarrier = false;
6593 llvm::Value *chunk = moduleTranslation.lookupValue(
6594 distributeOp.getDistScheduleChunkSize());
6595 llvm::CanonicalLoopInfo *loopInfo =
6596 findCurrentLoopInfo(moduleTranslation);
6597 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6598 ompBuilder->applyWorkshareLoop(
6599 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6600 convertToScheduleKind(schedule), chunk, isSimd,
6601 scheduleMod == omp::ScheduleModifier::monotonic,
6602 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6603 workshareLoopType, false, hasDistSchedule, chunk);
6604
6605 if (!wsloopIP)
6606 return wsloopIP.takeError();
6607 }
6608 if (failed(cleanupPrivateVars(distributeOp, builder, moduleTranslation,
6609 distributeOp.getLoc(), privVarsInfo)))
6610 return llvm::make_error<PreviouslyReportedError>();
6611
6612 return llvm::Error::success();
6613 };
6614
6616 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6617 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
6618 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6619 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6620 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
6621
6622 if (failed(handleError(afterIP, opInst)))
6623 return failure();
6624
6625 builder.restoreIP(*afterIP);
6626
6627 if (doDistributeReduction) {
6628 // Process the reductions if required.
6630 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6631 privateReductionVariables, isByRef,
6632 /*isNoWait*/ false, /*isTeamsReduction*/ true);
6633 }
6634 return success();
6635}
6636
6637/// Lowers the FlagsAttr which is applied to the module on the device
6638/// pass when offloading, this attribute contains OpenMP RTL globals that can
6639/// be passed as flags to the frontend, otherwise they are set to default
6640static LogicalResult
6641convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
6642 LLVM::ModuleTranslation &moduleTranslation) {
6643 if (!cast<mlir::ModuleOp>(op))
6644 return failure();
6645
6646 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6647
6648 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
6649 attribute.getOpenmpDeviceVersion());
6650
6651 if (attribute.getNoGpuLib())
6652 return success();
6653
6654 ompBuilder->createGlobalFlag(
6655 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
6656 "__omp_rtl_debug_kind");
6657 ompBuilder->createGlobalFlag(
6658 attribute
6659 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
6660 ,
6661 "__omp_rtl_assume_teams_oversubscription");
6662 ompBuilder->createGlobalFlag(
6663 attribute
6664 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
6665 ,
6666 "__omp_rtl_assume_threads_oversubscription");
6667 ompBuilder->createGlobalFlag(
6668 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
6669 "__omp_rtl_assume_no_thread_state");
6670 ompBuilder->createGlobalFlag(
6671 attribute
6672 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
6673 ,
6674 "__omp_rtl_assume_no_nested_parallelism");
6675 return success();
6676}
6677
6678static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
6679 omp::TargetOp targetOp,
6680 llvm::OpenMPIRBuilder &ompBuilder,
6681 llvm::vfs::FileSystem &vfs,
6682 llvm::StringRef parentName = "") {
6683 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
6684 assert(fileLoc && "No file found from location");
6685
6686 auto fileInfoCallBack = [&fileLoc]() {
6687 return std::pair<std::string, uint64_t>(
6688 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
6689 };
6690
6691 targetInfo =
6692 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
6693}
6694
6695static void
6696handleDeclareTargetMapVar(MapInfoData &mapData,
6697 LLVM::ModuleTranslation &moduleTranslation,
6698 llvm::IRBuilderBase &builder, llvm::Function *func) {
6699 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6700 "function only supported for target device codegen");
6701 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6702 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6703 // In the case of declare target mapped variables, the basePointer is
6704 // the reference pointer generated by the convertDeclareTargetAttr
6705 // method. Whereas the kernelValue is the original variable, so for
6706 // the device we must replace all uses of this original global variable
6707 // (stored in kernelValue) with the reference pointer (stored in
6708 // basePointer for declare target mapped variables), as for device the
6709 // data is mapped into this reference pointer and should be loaded
6710 // from it, the original variable is discarded. On host both exist and
6711 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
6712 // function to link the two variables in the runtime and then both the
6713 // reference pointer and the pointer are assigned in the kernel argument
6714 // structure for the host.
6715 if (mapData.IsDeclareTarget[i]) {
6716 // If the original map value is a constant, then we have to make sure all
6717 // of it's uses within the current kernel/function that we are going to
6718 // rewrite are converted to instructions, as we will be altering the old
6719 // use (OriginalValue) from a constant to an instruction, which will be
6720 // illegal and ICE the compiler if the user is a constant expression of
6721 // some kind e.g. a constant GEP.
6722 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6723 convertUsersOfConstantsToInstructions(constant, func, false);
6724
6725 // The users iterator will get invalidated if we modify an element,
6726 // so we populate this vector of uses to alter each user on an
6727 // individual basis to emit its own load (rather than one load for
6728 // all).
6730 for (llvm::User *user : mapData.OriginalValue[i]->users())
6731 userVec.push_back(user);
6732
6733 for (llvm::User *user : userVec) {
6734 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
6735 if (insn->getFunction() == func) {
6736 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6737 llvm::Value *substitute = mapData.BasePointers[i];
6738 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
6739 : mapOp.getVarPtr())) {
6740 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6741 substitute = builder.CreateLoad(
6742 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6743 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6744 }
6745 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6746 }
6747 }
6748 }
6749 }
6750 }
6751}
6752
// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
// arguments inside of the device kernel for use by
// the kernel. This enables different semantics such as
// the creation of temporary copies of data allowing
// semantics like read-only/no host write back kernel
// arguments.
//
// This currently implements a very light version of Clang's
// EmitParmDecl's handling of direct argument handling as well
// as a portion of the argument access generation based on
// capture types found at the end of emitOutlinedFunctionPrologue
// in Clang. The indirect path handling of EmitParmDecl's may be
// required for future work, but a direct 1-to-1 copy doesn't seem
// possible as the logic is rather scattered throughout Clang's
// lowering and perhaps we wish to deviate slightly.
//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generated different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
// \param retVal - This is the value that all uses of input inside of the
// kernel will be re-written to, the goal of this function is to generate
// an appropriate location for the kernel argument to be accessed from,
// e.g. ByRef will result in a temporary allocation location and then
// a store of the kernel argument into this allocated memory which
// will then be loaded from, ByCopy will use the allocated memory
// directly.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
    omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
    llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::IRBuilderBase::InsertPoint allocaIP,
    llvm::IRBuilderBase::InsertPoint codeGenIP,
    // NOTE(review): the tail of this parameter list (the `deallocIPs`
    // collection iterated below) appears to have been lost in extraction;
    // confirm against the upstream source.
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  builder.restoreIP(allocaIP);

  // Default capture kind when no matching map clause is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  BlockArgument mlirArg;
  // NOTE(review): the declaration of `blockArgsPairs` appears to have been
  // lost in extraction; confirm against the upstream source.
  cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
      blockArgsPairs);
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());

      // Find the corresponding entry block argument, which can be associated to
      // a map, use_device* or has_device* clause.
      for (auto &[val, arg] : blockArgsPairs) {
        if (mapOp.getResult() == val) {
          mlirArg = arg;
          break;
        }
      }
      assert(mlirArg && "expected to find entry block argument for map clause");
      break;
    }
  }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the allocation for the argument.
  llvm::Value *v = nullptr;
  if (omp::opInSharedDeviceContext(*targetOp) &&
      // NOTE(review): the remainder of this condition appears to have been
      // lost in extraction; confirm against the upstream source.
    // Use the beginning of the codeGenIP rather than the usual allocation point
    // for shared memory allocations because otherwise these would be done prior
    // to the target initialization call. Also, the exit block (where the
    // deallocation is placed) is only executed if the initialization call
    // succeeds.
    builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
    v = ompBuilder.createOMPAllocShared(builder, arg.getType());

    // Create deallocations in all provided deallocation points and then restore
    // the insertion point to right after the new allocations.
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    for (auto deallocIP : deallocIPs) {
      builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
      ompBuilder.createOMPFreeShared(builder, v, arg.getType());
    }
  } else {
    // Use the current point, which was previously set to allocaIP.
    v = builder.CreateAlloca(arg.getType(), allocaAS);

    // Cast back to the program address space so downstream code sees the
    // pointer type it expects.
    if (allocaAS != defaultAS && arg.getType()->isPointerTy())
      v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
  }

  // Spill the incoming kernel argument into the allocated slot.
  builder.CreateStore(&arg, v);

  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy: the slot itself is the value's storage.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // ByRef: load the pointer that was stored into the slot.
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //    alignment of %input address |
    //                                |
    //                   alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
6915
6916/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6917/// operation and populate output variables with their corresponding host value
6918/// (i.e. operand evaluated outside of the target region), based on their uses
6919/// inside of the target region.
6920///
6921/// Loop bounds and steps are only optionally populated, if output vectors are
6922/// provided.
6923static void
6924extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
6925 Value &numTeamsLower, Value &numTeamsUpper,
6926 Value &threadLimit,
6927 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
6928 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
6929 llvm::SmallVectorImpl<Value> *steps = nullptr) {
6930 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6931 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6932 blockArgIface.getHostEvalBlockArgs())) {
6933 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6934
6935 for (Operation *user : blockArg.getUsers()) {
6937 .Case([&](omp::TeamsOp teamsOp) {
6938 if (teamsOp.getNumTeamsLower() == blockArg)
6939 numTeamsLower = hostEvalVar;
6940 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6941 blockArg))
6942 numTeamsUpper = hostEvalVar;
6943 else if (!teamsOp.getThreadLimitVars().empty() &&
6944 teamsOp.getThreadLimit(0) == blockArg)
6945 threadLimit = hostEvalVar;
6946 else
6947 llvm_unreachable("unsupported host_eval use");
6948 })
6949 .Case([&](omp::ParallelOp parallelOp) {
6950 if (!parallelOp.getNumThreadsVars().empty() &&
6951 parallelOp.getNumThreads(0) == blockArg)
6952 numThreads = hostEvalVar;
6953 else
6954 llvm_unreachable("unsupported host_eval use");
6955 })
6956 .Case([&](omp::LoopNestOp loopOp) {
6957 auto processBounds =
6958 [&](OperandRange opBounds,
6959 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
6960 bool found = false;
6961 for (auto [i, lb] : llvm::enumerate(opBounds)) {
6962 if (lb == blockArg) {
6963 found = true;
6964 if (outBounds)
6965 (*outBounds)[i] = hostEvalVar;
6966 }
6967 }
6968 return found;
6969 };
6970 bool found =
6971 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6972 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6973 found;
6974 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6975 (void)found;
6976 assert(found && "unsupported host_eval use");
6977 })
6978 .DefaultUnreachable("unsupported host_eval use");
6979 }
6980 }
6981}
6982
6983/// If \p op is of the given type parameter, return it casted to that type.
6984/// Otherwise, if its immediate parent operation (or some other higher-level
6985/// parent, if \p immediateParent is false) is of that type, return that parent
6986/// casted to the given type.
6987///
6988/// If \p op is \c null or neither it or its parent(s) are of the specified
6989/// type, return a \c null operation.
6990template <typename OpTy>
6991static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6992 if (!op)
6993 return OpTy();
6994
6995 if (OpTy casted = dyn_cast<OpTy>(op))
6996 return casted;
6997
6998 if (immediateParent)
6999 return dyn_cast_if_present<OpTy>(op->getParentOp());
7000
7001 return op->getParentOfType<OpTy>();
7002}
7003
7004/// If the given \p value is defined by an \c llvm.mlir.constant operation and
7005/// it is of an integer type, return its value.
7006static std::optional<int64_t> extractConstInteger(Value value) {
7007 if (!value)
7008 return std::nullopt;
7009
7010 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
7011 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7012 return constAttr.getInt();
7013
7014 return std::nullopt;
7015}
7016
7017static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
7018 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
7019 uint64_t sizeInBytes = sizeInBits / 8;
7020 return sizeInBytes;
7021}
7022
7023template <typename OpTy>
7024static uint64_t getReductionDataSize(OpTy &op) {
7025 if (op.getNumReductionVars() > 0) {
7027 collectReductionDecls(op, reductions);
7028
7030 members.reserve(reductions.size());
7031 for (omp::DeclareReductionOp &red : reductions) {
7032 // For by-ref reductions, use the actual element type rather than the
7033 // pointer type so that the buffer size matches the access pattern in
7034 // the copy/reduce callbacks generated by OMPIRBuilder.
7035 if (red.getByrefElementType())
7036 members.push_back(*red.getByrefElementType());
7037 else
7038 members.push_back(red.getType());
7039 }
7040 Operation *opp = op.getOperation();
7041 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7042 opp->getContext(), members, /*isPacked=*/false);
7043 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
7044 return getTypeByteSize(structType, dl);
7045 }
7046 return 0;
7047}
7048
7049/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
7050/// values as stated by the corresponding clauses, if constant.
7051///
7052/// These default values must be set before the creation of the outlined LLVM
7053/// function for the target region, so that they can be used to initialize the
7054/// corresponding global `ConfigurationEnvironmentTy` structure.
7055static void
7056initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
7057 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7058 bool isTargetDevice, bool isGPU) {
7059 // TODO: Handle constant 'if' clauses.
7060
7061 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
7062 if (!isTargetDevice) {
7063 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
7064 threadLimit);
7065 } else {
7066 // In the target device, values for these clauses are not passed as
7067 // host_eval, but instead evaluated prior to entry to the region. This
7068 // ensures values are mapped and available inside of the target region.
7069 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7070 numTeamsLower = teamsOp.getNumTeamsLower();
7071 // Handle num_teams upper bounds (only first value for now)
7072 if (!teamsOp.getNumTeamsUpperVars().empty())
7073 numTeamsUpper = teamsOp.getNumTeams(0);
7074 if (!teamsOp.getThreadLimitVars().empty())
7075 threadLimit = teamsOp.getThreadLimit(0);
7076 }
7077
7078 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
7079 if (!parallelOp.getNumThreadsVars().empty())
7080 numThreads = parallelOp.getNumThreads(0);
7081 }
7082 }
7083
7084 // Handle clauses impacting the number of teams.
7085
7086 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7087 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7088 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
7089 // match clang and set min and max to the same value.
7090 if (numTeamsUpper) {
7091 if (auto val = extractConstInteger(numTeamsUpper))
7092 minTeamsVal = maxTeamsVal = *val;
7093 } else {
7094 minTeamsVal = maxTeamsVal = 0;
7095 }
7096 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
7097 /*immediateParent=*/true) ||
7099 /*immediateParent=*/true)) {
7100 minTeamsVal = maxTeamsVal = 1;
7101 } else {
7102 minTeamsVal = maxTeamsVal = -1;
7103 }
7104
7105 // Handle clauses impacting the number of threads.
7106
7107 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
7108 if (!clauseValue)
7109 return;
7110
7111 if (auto val = extractConstInteger(clauseValue))
7112 result = *val;
7113
7114 // Found an applicable clause, so it's not undefined. Mark as unknown
7115 // because it's not constant.
7116 if (result < 0)
7117 result = 0;
7118 };
7119
7120 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
7121 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7122 if (!targetOp.getThreadLimitVars().empty())
7123 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7124 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7125
7126 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
7127 int32_t maxThreadsVal = -1;
7129 setMaxValueFromClause(numThreads, maxThreadsVal);
7130 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
7131 /*immediateParent=*/true))
7132 maxThreadsVal = 1;
7133
7134 // For max values, < 0 means unset, == 0 means set but unknown. Select the
7135 // minimum value between 'max_threads' and 'thread_limit' clauses that were
7136 // set.
7137 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7138 if (combinedMaxThreadsVal < 0 ||
7139 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7140 combinedMaxThreadsVal = teamsThreadLimitVal;
7141
7142 if (combinedMaxThreadsVal < 0 ||
7143 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7144 combinedMaxThreadsVal = maxThreadsVal;
7145
7146 int32_t reductionDataSize = 0;
7147 if (isGPU && capturedOp) {
7148 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
7149 reductionDataSize = getReductionDataSize(teamsOp);
7150 }
7151
7152 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
7153 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7154 switch (execMode) {
7155 case omp::TargetExecMode::bare:
7156 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7157 break;
7158 case omp::TargetExecMode::generic:
7159 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7160 break;
7161 case omp::TargetExecMode::spmd:
7162 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7163 break;
7164 case omp::TargetExecMode::no_loop:
7165 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
7166 break;
7167 }
7168 attrs.MinTeams = minTeamsVal;
7169 attrs.MaxTeams.front() = maxTeamsVal;
7170 attrs.MinThreads = 1;
7171 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7172 attrs.ReductionDataSize = reductionDataSize;
7173 // TODO: Allow modified buffer length similar to
7174 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
7175 if (attrs.ReductionDataSize != 0)
7176 attrs.ReductionBufferLength = 1024;
7177}
7178
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  // Number of collapsed loops in the captured loop nest (0 when no loop nest
  // is captured); sizes the bound/step vectors filled below.
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  bool hostEvalTripCount;
  targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
  if (hostEvalTripCount) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // If any bound or step failed to translate, give up on the trip count
      // entirely rather than leaving a partial product behind.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop of the nest: initialize the accumulator.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Default to an undefined device ID, overridden (sign-extended/truncated to
  // i64) by the value of the `device` clause when one is present.
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
7267
7268static llvm::omp::OMPDynGroupprivateFallbackType
7269getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr) {
7270 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7271 : omp::FallbackModifier::default_mem;
7272 switch (fb) {
7273 case omp::FallbackModifier::abort:
7274 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7275 case omp::FallbackModifier::null:
7276 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7277 case omp::FallbackModifier::default_mem:
7278 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
7279 }
7280
7281 llvm_unreachable("unexpected dyn_groupprivate fallback type");
7282}
7283
/// Translate an `omp.target` operation to LLVM IR: its region is outlined into
/// a kernel function and the host-side launch (or device-side entry point,
/// depending on the compilation mode) is emitted through
/// OpenMPIRBuilder::createTarget.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);

  // The current debug location already has the DISubprogram for the outlined
  // function that will be created for the target op. We save it here so that
  // we can set it on the outlined function.
  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // During the handling of target op, we will generate instructions in the
  // parent function like call to the oulined function or branch to a new
  // BasicBlock. We set the debug location here to parent function so that those
  // get the correct debug locations. For outlined functions, the normal MLIR op
  // conversion will automatically pick the correct location.
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
        outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block
  // argument that corresponds to the MapInfoOp corresponding to the private
  // var in question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;
  TargetDirectiveEnumTy targetDirective =
      getTargetDirectiveEnumTyFromOp(&opInst);

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that need a map entry are of interest here.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback generating the body of the outlined target function: map block
  // arguments to LLVM values, perform privatization, then translate the
  // target region itself.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                    ArrayRef<llvm::BasicBlock *> deallocBlocks)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Bind each map/has_device_addr block argument to the translated LLVM
    // value of the underlying variable pointer.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    PrivateVarsInfo privateVarsInfo(targetOp);

        allocatePrivateVars(targetOp, builder, moduleTranslation,
                            privateVarsInfo, allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                    &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

                                        moduleTranslation, allocaIP, deallocBlocks);
        targetRegion, "omp.target", builder, moduleTranslation);

    if (failed(handleError(exitBlock, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(exitBlock.get()->getTerminator());

    if (failed(cleanupPrivateVars(targetOp, builder, moduleTranslation,
                                  targetOp.getLoc(), privateVarsInfo)))
      return llvm::make_error<PreviouslyReportedError>();

    return builder.saveIP();
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  getTargetEntryUniqueInfo(entryInfo, targetOp,
                           *moduleTranslation.getOpenMPBuilder(),
                           moduleTranslation.getFileSystem(), parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                targetDirective);

    // Append a null entry for the implicit dyn_ptr argument so the argument
    // count sent to the runtime already includes it.
    auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
    combinedInfos.BasePointers.push_back(nullPtr);
    combinedInfos.Pointers.push_back(nullPtr);
    combinedInfos.DevicePointers.push_back(
        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfos.Sizes.push_back(builder.getInt64(0));
    combinedInfos.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
    if (!combinedInfos.Names.empty())
      combinedInfos.Names.push_back(nullPtr);
    combinedInfos.Mappers.push_back(nullptr);

    return combinedInfos;
  };

  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP,
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(targetOp, mapData, arg, input, retVal,
                                        builder, *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP, deallocIPs);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likley need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);

  llvm::OpenMPIRBuilder::DependenciesInfo dds;
  if (failed(buildDependData(
          targetOp.getDependVars(), targetOp.getDependKinds(),
          targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
          builder, moduleTranslation, dds)))
    return failure();

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Resolve the user-defined mapper function (if any) for map entry `i`.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  // Lower the dyn_groupprivate clause size to an unsigned i32, as expected by
  // the runtime interface.
  Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
  llvm::Value *dynSizeVal = nullptr;
  if (dynGroupPrivateSize) {
    dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
    dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
                                       /*isSigned=*/false);
  }

  llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
      getDynGroupprivateFallbackType(targetOp.getDynGroupprivateFallbackAttr());

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
          info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
          genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
          targetOp.getNowait(), dynSizeVal, fallbackType);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (dds.DepArray)
    builder.CreateFree(dds.DepArray);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
7632
7633static LogicalResult
7634convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
7635 llvm::OpenMPIRBuilder *ompBuilder,
7636 LLVM::ModuleTranslation &moduleTranslation) {
7637 // Amend omp.declare_target by deleting the IR of the outlined functions
7638 // created for target regions. They cannot be filtered out from MLIR earlier
7639 // because the omp.target operation inside must be translated to LLVM, but
7640 // the wrapper functions themselves must not remain at the end of the
7641 // process. We know that functions where omp.declare_target does not match
7642 // omp.is_target_device at this stage can only be wrapper functions because
7643 // those that aren't are removed earlier as an MLIR transformation pass.
7644 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7645 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7646 op->getParentOfType<ModuleOp>().getOperation())) {
7647 if (!offloadMod.getIsTargetDevice())
7648 return success();
7649
7650 omp::DeclareTargetDeviceType declareType =
7651 attribute.getDeviceType().getValue();
7652
7653 if (declareType == omp::DeclareTargetDeviceType::host) {
7654 llvm::Function *llvmFunc =
7655 moduleTranslation.lookupFunction(funcOp.getName());
7656 llvmFunc->dropAllReferences();
7657 llvmFunc->eraseFromParent();
7658
7659 // Invalidate the builder's current insertion point, as it now points to
7660 // a deleted block.
7661 ompBuilder->Builder.ClearInsertionPoint();
7662 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
7663 }
7664 }
7665 return success();
7666 }
7667
7668 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7669 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7670 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7671 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7672 bool isDeclaration = gOp.isDeclaration();
7673 bool isExternallyVisible =
7674 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
7675 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
7676 llvm::StringRef mangledName = gOp.getSymName();
7677 auto captureClause =
7678 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
7679 auto deviceClause =
7680 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
7681 // unused for MLIR at the moment, required in Clang for book
7682 // keeping
7683 std::vector<llvm::GlobalVariable *> generatedRefs;
7684
7685 std::vector<llvm::Triple> targetTriple;
7686 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7687 op->getParentOfType<mlir::ModuleOp>()->getAttr(
7688 LLVM::LLVMDialect::getTargetTripleAttrName()));
7689 if (targetTripleAttr)
7690 targetTriple.emplace_back(targetTripleAttr.data());
7691
7692 auto fileInfoCallBack = [&loc]() {
7693 std::string filename = "";
7694 std::uint64_t lineNo = 0;
7695
7696 if (loc) {
7697 filename = loc.getFilename().str();
7698 lineNo = loc.getLine();
7699 }
7700
7701 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7702 lineNo);
7703 };
7704
7705 llvm::vfs::FileSystem &vfs = moduleTranslation.getFileSystem();
7706
7707 ompBuilder->registerTargetGlobalVariable(
7708 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7709 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7710 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7711 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
7712 gVal->getType(), gVal);
7713
7714 if (ompBuilder->Config.isTargetDevice() &&
7715 (attribute.getCaptureClause().getValue() !=
7716 mlir::omp::DeclareTargetCaptureClause::to ||
7717 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7718 ompBuilder->getAddrOfDeclareTargetVar(
7719 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7720 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
7721 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7722 gVal->getType(), /*GlobalInitializer*/ nullptr,
7723 /*VariableLinkage*/ nullptr);
7724 }
7725 }
7726 }
7727
7728 return success();
7729}
7730
7731namespace {
7732
7733/// Implementation of the dialect interface that converts operations belonging
7734/// to the OpenMP dialect to LLVM IR.
7735class OpenMPDialectLLVMIRTranslationInterface
7736 : public LLVMTranslationDialectInterface {
7737public:
7738 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
7739
7740 /// Translates the given operation to LLVM IR using the provided IR builder
7741 /// and saving the state in `moduleTranslation`.
7742 LogicalResult
7743 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7744 LLVM::ModuleTranslation &moduleTranslation) const final;
7745
7746 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
7747 /// runtime calls, or operation amendments
7748 LogicalResult
7749 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7750 NamedAttribute attribute,
7751 LLVM::ModuleTranslation &moduleTranslation) const final;
7752
7753 /// Records the LLVM alloc pointer produced for an OMP ALLOCATE variable so
7754 /// that the paired omp.allocate_free op can generate the matching
7755 /// __kmpc_free call.
7756 void registerAllocatedPtr(Value var, llvm::Value *ptr) const {
7757 ompAllocatedPtrs[var] = ptr;
7758 }
7759
7760 /// Returns the LLVM alloc pointer previously registered for var, or
7761 /// nullptr if no allocation was recorded.
7762 llvm::Value *lookupAllocatedPtr(Value var) const {
7763 auto it = ompAllocatedPtrs.find(var);
7764 return it != ompAllocatedPtrs.end() ? it->second : nullptr;
7765 }
7766
7767private:
7768 /// Maps each MLIR variable value that appeared in an omp.allocate_dir op to
7769 /// the LLVM pointer returned by the corresponding __kmpc_alloc call. The
7770 /// paired omp.allocate_free op looks up these pointers to emit __kmpc_free.
7771 mutable DenseMap<Value, llvm::Value *> ompAllocatedPtrs;
7772};
7773
7774} // namespace
7775
7776LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7777 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7778 NamedAttribute attribute,
7779 LLVM::ModuleTranslation &moduleTranslation) const {
7780 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7781 attribute.getName())
7782 .Case("omp.is_target_device",
7783 [&](Attribute attr) {
7784 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7785 llvm::OpenMPIRBuilderConfig &config =
7786 moduleTranslation.getOpenMPBuilder()->Config;
7787 config.setIsTargetDevice(deviceAttr.getValue());
7788 return success();
7789 }
7790 return failure();
7791 })
7792 .Case("omp.is_gpu",
7793 [&](Attribute attr) {
7794 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7795 llvm::OpenMPIRBuilderConfig &config =
7796 moduleTranslation.getOpenMPBuilder()->Config;
7797 config.setIsGPU(gpuAttr.getValue());
7798 return success();
7799 }
7800 return failure();
7801 })
7802 .Case("omp.host_ir_filepath",
7803 [&](Attribute attr) {
7804 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7805 llvm::OpenMPIRBuilder *ompBuilder =
7806 moduleTranslation.getOpenMPBuilder();
7807 ompBuilder->loadOffloadInfoMetadata(
7808 moduleTranslation.getFileSystem(), filepathAttr.getValue());
7809 return success();
7810 }
7811 return failure();
7812 })
7813 .Case("omp.flags",
7814 [&](Attribute attr) {
7815 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7816 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
7817 return failure();
7818 })
7819 .Case("omp.version",
7820 [&](Attribute attr) {
7821 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7822 llvm::OpenMPIRBuilder *ompBuilder =
7823 moduleTranslation.getOpenMPBuilder();
7824 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
7825 versionAttr.getVersion());
7826 return success();
7827 }
7828 return failure();
7829 })
7830 .Case("omp.declare_target",
7831 [&](Attribute attr) {
7832 if (auto declareTargetAttr =
7833 dyn_cast<omp::DeclareTargetAttr>(attr)) {
7834 llvm::OpenMPIRBuilder *ompBuilder =
7835 moduleTranslation.getOpenMPBuilder();
7836 return convertDeclareTargetAttr(op, declareTargetAttr,
7837 ompBuilder, moduleTranslation);
7838 }
7839 return failure();
7840 })
7841 .Case("omp.requires",
7842 [&](Attribute attr) {
7843 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7844 using Requires = omp::ClauseRequires;
7845 Requires flags = requiresAttr.getValue();
7846 llvm::OpenMPIRBuilderConfig &config =
7847 moduleTranslation.getOpenMPBuilder()->Config;
7848 config.setHasRequiresReverseOffload(
7849 bitEnumContainsAll(flags, Requires::reverse_offload));
7850 config.setHasRequiresUnifiedAddress(
7851 bitEnumContainsAll(flags, Requires::unified_address));
7852 config.setHasRequiresUnifiedSharedMemory(
7853 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7854 config.setHasRequiresDynamicAllocators(
7855 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7856 return success();
7857 }
7858 return failure();
7859 })
7860 .Case("omp.target_triples",
7861 [&](Attribute attr) {
7862 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7863 llvm::OpenMPIRBuilderConfig &config =
7864 moduleTranslation.getOpenMPBuilder()->Config;
7865 config.TargetTriples.clear();
7866 config.TargetTriples.reserve(triplesAttr.size());
7867 for (Attribute tripleAttr : triplesAttr) {
7868 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7869 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7870 else
7871 return failure();
7872 }
7873 return success();
7874 }
7875 return failure();
7876 })
7877 .Default([](Attribute) {
7878 // Fall through for omp attributes that do not require lowering.
7879 return success();
7880 })(attribute.getValue());
7881
7882 return failure();
7883}
7884
7885// Returns true if the operation is not inside a TargetOp, it is part of a
7886// function and that function is not declare target.
7887static bool isHostDeviceOp(Operation *op) {
7888 // Assumes no reverse offloading
7889 if (op->getParentOfType<omp::TargetOp>())
7890 return false;
7891
7892 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
7893 if (auto declareTargetIface =
7894 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7895 parentFn.getOperation()))
7896 if (declareTargetIface.isDeclareTarget() &&
7897 declareTargetIface.getDeclareTargetDeviceType() !=
7898 mlir::omp::DeclareTargetDeviceType::host)
7899 return false;
7900
7901 return true;
7902 }
7903
7904 return false;
7905}
7906
7907static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7908 llvm::Module *llvmModule) {
7909 llvm::Type *i64Ty = builder.getInt64Ty();
7910 llvm::Type *i32Ty = builder.getInt32Ty();
7911 llvm::Type *returnType = builder.getPtrTy(0);
7912 llvm::FunctionType *fnType =
7913 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7914 llvm::Function *func = cast<llvm::Function>(
7915 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7916 return func;
7917}
7918
7919template <typename T>
7920static llvm::Value *
7921getAllocationSize(llvm::IRBuilderBase &builder,
7922 LLVM::ModuleTranslation &moduleTranslation, T op) {
7923 llvm::DataLayout dataLayout =
7924 moduleTranslation.getLLVMModule()->getDataLayout();
7925 llvm::Type *llvmHeapTy =
7926 moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
7927
7928 auto alignment = op.getMemAlignment();
7929 llvm::TypeSize typeSize = llvm::alignTo(
7930 dataLayout.getTypeStoreSize(llvmHeapTy),
7931 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
7932
7933 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7934 return builder.CreateMul(
7935 allocSize,
7936 builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
7937 builder.getInt64Ty(),
7938 /*isSigned=*/false));
7939}
7940
7941template <>
7942llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
7943 LLVM::ModuleTranslation &moduleTranslation,
7944 omp::TargetAllocMemOp op) {
7945 llvm::DataLayout dataLayout =
7946 moduleTranslation.getLLVMModule()->getDataLayout();
7947 llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
7948 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
7949 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7950 for (auto typeParam : op.getTypeparams()) {
7951 allocSize = builder.CreateMul(
7952 allocSize,
7953 builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
7954 builder.getInt64Ty(),
7955 /*isSigned=*/false));
7956 }
7957 return allocSize;
7958}
7959
7960static LogicalResult
7961convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7962 LLVM::ModuleTranslation &moduleTranslation) {
7963 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7964 if (!allocMemOp)
7965 return failure();
7966
7967 // Get "omp_target_alloc" function
7968 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7969 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7970 // Get the corresponding device value in llvm
7971 mlir::Value deviceNum = allocMemOp.getDevice();
7972 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7973 // Get the allocation size.
7974 llvm::Value *allocSize =
7975 getAllocationSize(builder, moduleTranslation, allocMemOp);
7976 // Create call to "omp_target_alloc" with the args as translated llvm values.
7977 llvm::CallInst *call =
7978 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7979 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7980
7981 // Map the result
7982 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7983 return success();
7984}
7985
7986static LogicalResult
7987convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
7988 llvm::IRBuilderBase &builder,
7989 LLVM::ModuleTranslation &moduleTranslation) {
7990 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7991 llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
7992 moduleTranslation.mapValue(allocMemOp.getResult(),
7993 ompBuilder->createOMPAllocShared(builder, size));
7994 return success();
7995}
7996
/// Converts an omp.allocate_dir operation into LLVM IR: emits one
/// __kmpc_alloc / __kmpc_aligned_alloc call per listed variable and records
/// each returned pointer on \p ompIface so the paired omp.allocate_free op
/// can later emit the matching __kmpc_free calls.
static LogicalResult
convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
  auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
  SmallVector<Value> vars = allocateDirOp.getVarList();
  std::optional<int64_t> alignAttr = allocateDirOp.getAlign();

  // Resolve the allocator operand to an untyped pointer: integer handles go
  // through inttoptr, pointers are cast to the default address space, and a
  // null pointer is passed when no allocator is specified.
  llvm::Value *allocator;
  if (auto allocatorVar = allocateDirOp.getAllocator()) {
    allocator = moduleTranslation.lookupValue(allocatorVar);
    if (allocator->getType()->isIntegerTy())
      allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
    else if (allocator->getType()->isPointerTy())
      allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
          allocator, builder.getPtrTy());
  } else {
    allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
  }

  for (Value var : vars) {
    llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType());

    // Opaque pointers lose the element type, so trace the variable back to
    // its defining GlobalOp to recover the allocated type; falls back to
    // llvmVarTy when the value does not originate from a global.
    llvm::Type *typeToInspect = llvmVarTy;
    if (llvmVarTy->isPointerTy()) {
      Value baseVar = getBaseValueForTypeLookup(var);
      if (Operation *globalOp = getGlobalOpFromValue(baseVar)) {
        if (auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
          typeToInspect = moduleTranslation.convertType(gop.getGlobalType());
      }
    }

    // Compute the byte size. For (possibly nested) arrays, multiply out all
    // dimensions and scale by the innermost element size.
    llvm::Value *size;
    if (auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
      llvm::Value *elementCount = builder.getInt64(1);
      llvm::Type *currentType = arrTy;
      while (auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
        elementCount = builder.CreateMul(
            elementCount, builder.getInt64(nestedArrTy->getNumElements()));
        currentType = nestedArrTy->getElementType();
      }
      // NOTE(review): the bits/8 division truncates for element types whose
      // size is not a byte multiple (e.g. i1) — confirm such element types
      // cannot reach this path.
      uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
      size =
          builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
    } else {
      size = builder.getInt64(
          dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
    }

    // Use the explicit align clause when present, otherwise the ABI alignment
    // of the inspected type.
    uint64_t alignValue =
        alignAttr ? alignAttr.value()
                  : dataLayout.getABITypeAlign(typeToInspect).value();
    llvm::Value *alignConst = builder.getInt64(alignValue);
    // Align the size: ((size + align - 1) / align) * align
    size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true);
    size = builder.CreateUDiv(size, alignConst);
    size = builder.CreateMul(size, alignConst, "", true);

    std::string allocName =
        ompBuilder->createPlatformSpecificName({".void.addr"});
    // An explicit align clause selects the aligned-alloc entry point.
    llvm::CallInst *allocCall;
    if (alignAttr.has_value()) {
      allocCall = ompBuilder->createOMPAlignedAlloc(
          ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
          allocName);
    } else {
      allocCall =
          ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
    }
    // Record the alloc pointer keyed by the MLIR variable value.
    ompIface.registerAllocatedPtr(var, allocCall);
  }

  return success();
}
8079
8080static LogicalResult
8081convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder,
8082 LLVM::ModuleTranslation &moduleTranslation,
8083 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8084 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8085 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8086 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8087
8088 llvm::Value *allocator;
8089 if (auto allocatorVar = freeOp.getAllocator()) {
8090 allocator = moduleTranslation.lookupValue(allocatorVar);
8091 if (allocator->getType()->isIntegerTy())
8092 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8093 else if (allocator->getType()->isPointerTy())
8094 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8095 allocator, builder.getPtrTy());
8096 } else {
8097 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8098 }
8099
8100 // Emit __kmpc_free for each variable in reverse allocation order.
8101 SmallVector<Value> vars = freeOp.getVarList();
8102 for (Value var : llvm::reverse(vars)) {
8103 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
8104 if (!allocPtr)
8105 return opInst.emitError("omp.allocate_free: no allocation recorded");
8106 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator, "");
8107 }
8108
8109 return success();
8110}
8111
8112static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
8113 llvm::Module *llvmModule) {
8114 llvm::Type *ptrTy = builder.getPtrTy(0);
8115 llvm::Type *i32Ty = builder.getInt32Ty();
8116 llvm::Type *voidTy = builder.getVoidTy();
8117 llvm::FunctionType *fnType =
8118 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
8119 llvm::Function *func = dyn_cast<llvm::Function>(
8120 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
8121 return func;
8122}
8123
8124static LogicalResult
8125convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
8126 LLVM::ModuleTranslation &moduleTranslation) {
8127 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8128 if (!freeMemOp)
8129 return failure();
8130
8131 // Get "omp_target_free" function
8132 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8133 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
8134 // Get the corresponding device value in llvm
8135 mlir::Value deviceNum = freeMemOp.getDevice();
8136 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
8137 // Get the corresponding heapref value in llvm
8138 mlir::Value heapref = freeMemOp.getHeapref();
8139 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
8140 // Convert heapref int to ptr and call "omp_target_free"
8141 llvm::Value *intToPtr =
8142 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8143 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
8144 return success();
8145}
8146
8147static LogicalResult
8148convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
8149 llvm::IRBuilderBase &builder,
8150 LLVM::ModuleTranslation &moduleTranslation) {
8151 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8152 llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
8153 ompBuilder->createOMPFreeShared(
8154 builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
8155 return success();
8156}
8157
8158/// Converts an OpenMP groupprivate operation into LLVM IR.
8159static LogicalResult
8160convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
8161 LLVM::ModuleTranslation &moduleTranslation) {
8162 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8163 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8164
8165 if (failed(checkImplementationStatus(opInst)))
8166 return failure();
8167
8168 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
8169
8170 // Determine whether group-private storage should be allocated based on
8171 // device_type. When not specified, default to 'any' (allocate on both).
8172 bool shouldAllocate = true;
8173 switch (groupprivateOp.getDeviceType().value_or(
8174 mlir::omp::DeclareTargetDeviceType::any)) {
8175 case mlir::omp::DeclareTargetDeviceType::host:
8176 shouldAllocate = !isTargetDevice;
8177 break;
8178 case mlir::omp::DeclareTargetDeviceType::nohost:
8179 shouldAllocate = isTargetDevice;
8180 break;
8181 case mlir::omp::DeclareTargetDeviceType::any:
8182 shouldAllocate = true;
8183 break;
8184 }
8185
8186 // Look up the global variable directly by symbol name.
8188 &opInst, groupprivateOp.getSymNameAttr());
8189 if (!global)
8190 return opInst.emitError()
8191 << "expected symbol '" << groupprivateOp.getSymName()
8192 << "' to reference an LLVM global variable";
8193
8194 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
8195 llvm::Type *varType = moduleTranslation.convertType(global.getType());
8196 std::string varName = globalValue->getName().str();
8197
8198 llvm::Value *resultPtr;
8199 if (shouldAllocate && isTargetDevice) {
8200 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8201 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8202 unsigned sharedAddressSpace;
8203 if (targetTriple.isAMDGCN())
8204 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8205 else if (targetTriple.isNVPTX())
8206 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8207 else
8208 return opInst.emitError() << "groupprivate is not supported for target: "
8209 << targetTriple.str();
8210 llvm::GlobalVariable *sharedVar = new llvm::GlobalVariable(
8211 *llvmModule, varType, /*isConstant=*/false,
8212 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8213 varName, /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
8214 sharedAddressSpace,
8215 /*isExternallyInitialized=*/false);
8216 resultPtr = sharedVar;
8217 } else {
8218 if (shouldAllocate && !isTargetDevice)
8219 opInst.emitWarning("groupprivate directive is currently ignored on the "
8220 "host, using original global");
8221 resultPtr = globalValue;
8222 }
8223
8224 moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
8225 return success();
8226}
8227
8228/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
8229/// OpenMP runtime calls).
8230LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8231 Operation *op, llvm::IRBuilderBase &builder,
8232 LLVM::ModuleTranslation &moduleTranslation) const {
8233 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8234
8235 if (ompBuilder->Config.isTargetDevice() &&
8236 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8237 op) &&
8238 isHostDeviceOp(op))
8239 return op->emitOpError() << "unsupported host op found in device";
8240
8241 // For each loop, introduce one stack frame to hold loop information. Ensure
8242 // this is only done for the outermost loop wrapper to prevent introducing
8243 // multiple stack frames for a single loop. Initially set to null, the loop
8244 // information structure is initialized during translation of the nested
8245 // omp.loop_nest operation, making it available to translation of all loop
8246 // wrappers after their body has been successfully translated.
8247 bool isOutermostLoopWrapper =
8248 isa_and_present<omp::LoopWrapperInterface>(op) &&
8249 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
8250
8251 // The TASKLOOP construct is implemented with an outer taskloop.context
8252 // operation which is not a loop wrapper, containing an inner taskloop
8253 // operation which is a loop wrapper. The stack frame should be pushed when
8254 // translating the outer taskloop.context and popped when translating the
8255 // inner taskloop which is a loop wrapper. We need access to the loop
8256 // information in the outer taskloop context so we need to create it and pop
8257 // it around the taskloop context not the inner loop wrapper.
8258 if (isa<omp::TaskloopContextOp>(op))
8259 isOutermostLoopWrapper = true;
8260 else if (isa<omp::TaskloopWrapperOp>(op))
8261 isOutermostLoopWrapper = false;
8262
8263 if (isOutermostLoopWrapper)
8264 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
8265
8266 auto result =
8267 llvm::TypeSwitch<Operation *, LogicalResult>(op)
8268 .Case([&](omp::BarrierOp op) -> LogicalResult {
8270 return failure();
8271
8272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8273 ompBuilder->createBarrier(builder.saveIP(),
8274 llvm::omp::OMPD_barrier);
8275 LogicalResult res = handleError(afterIP, *op);
8276 if (res.succeeded()) {
8277 // If the barrier generated a cancellation check, the insertion
8278 // point might now need to be changed to a new continuation block
8279 builder.restoreIP(*afterIP);
8280 }
8281 return res;
8282 })
8283 .Case([&](omp::TaskyieldOp op) {
8285 return failure();
8286
8287 ompBuilder->createTaskyield(builder.saveIP());
8288 return success();
8289 })
8290 .Case([&](omp::FlushOp op) {
8292 return failure();
8293
8294 // No support in Openmp runtime function (__kmpc_flush) to accept
8295 // the argument list.
8296 // OpenMP standard states the following:
8297 // "An implementation may implement a flush with a list by ignoring
8298 // the list, and treating it the same as a flush without a list."
8299 //
8300 // The argument list is discarded so that, flush with a list is
8301 // treated same as a flush without a list.
8302 ompBuilder->createFlush(builder.saveIP());
8303 return success();
8304 })
8305 .Case([&](omp::ParallelOp op) {
8306 return convertOmpParallel(op, builder, moduleTranslation);
8307 })
8308 .Case([&](omp::MaskedOp) {
8309 return convertOmpMasked(*op, builder, moduleTranslation);
8310 })
8311 .Case([&](omp::MasterOp) {
8312 return convertOmpMaster(*op, builder, moduleTranslation);
8313 })
8314 .Case([&](omp::CriticalOp) {
8315 return convertOmpCritical(*op, builder, moduleTranslation);
8316 })
8317 .Case([&](omp::OrderedRegionOp) {
8318 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
8319 })
8320 .Case([&](omp::OrderedOp) {
8321 return convertOmpOrdered(*op, builder, moduleTranslation);
8322 })
8323 .Case([&](omp::WsloopOp) {
8324 return convertOmpWsloop(*op, builder, moduleTranslation);
8325 })
8326 .Case([&](omp::SimdOp) {
8327 return convertOmpSimd(*op, builder, moduleTranslation);
8328 })
8329 .Case([&](omp::AtomicReadOp) {
8330 return convertOmpAtomicRead(*op, builder, moduleTranslation);
8331 })
8332 .Case([&](omp::AtomicWriteOp) {
8333 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
8334 })
8335 .Case([&](omp::AtomicUpdateOp op) {
8336 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
8337 })
8338 .Case([&](omp::AtomicCaptureOp op) {
8339 return convertOmpAtomicCapture(op, builder, moduleTranslation);
8340 })
8341 .Case([&](omp::CancelOp op) {
8342 return convertOmpCancel(op, builder, moduleTranslation);
8343 })
8344 .Case([&](omp::CancellationPointOp op) {
8345 return convertOmpCancellationPoint(op, builder, moduleTranslation);
8346 })
8347 .Case([&](omp::SectionsOp) {
8348 return convertOmpSections(*op, builder, moduleTranslation);
8349 })
8350 .Case([&](omp::ScopeOp op) {
8351 return convertOmpScope(op, builder, moduleTranslation);
8352 })
8353 .Case([&](omp::SingleOp op) {
8354 return convertOmpSingle(op, builder, moduleTranslation);
8355 })
8356 .Case([&](omp::TeamsOp op) {
8357 return convertOmpTeams(op, builder, moduleTranslation);
8358 })
8359 .Case([&](omp::TaskOp op) {
8360 return convertOmpTaskOp(op, builder, moduleTranslation);
8361 })
8362 .Case([&](omp::TaskloopWrapperOp op) {
8363 return convertOmpTaskloopWrapperOp(op, builder, moduleTranslation);
8364 })
8365 .Case([&](omp::TaskloopContextOp op) {
8366 return convertOmpTaskloopContextOp(op, builder, moduleTranslation);
8367 })
8368 .Case([&](omp::TaskgroupOp op) {
8369 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
8370 })
8371 .Case([&](omp::TaskwaitOp op) {
8372 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
8373 })
8374 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8375 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8376 omp::CriticalDeclareOp>([](auto op) {
8377 // `yield` and `terminator` can be just omitted. The block structure
8378 // was created in the region that handles their parent operation.
8379 // `declare_reduction` will be used by reductions and is not
8380 // converted directly, skip it.
8381 // `declare_mapper` and `declare_mapper.info` are handled whenever
8382 // they are referred to through a `map` clause.
8383 // `critical.declare` is only used to declare names of critical
8384 // sections which will be used by `critical` ops and hence can be
8385 // ignored for lowering. The OpenMP IRBuilder will create unique
8386 // name for critical section names.
8387 return success();
8388 })
8389 .Case([&](omp::ThreadprivateOp) {
8390 return convertOmpThreadprivate(*op, builder, moduleTranslation);
8391 })
8392 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8393 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
8394 return convertOmpTargetData(op, builder, moduleTranslation);
8395 })
8396 .Case([&](omp::TargetOp) {
8397 return convertOmpTarget(*op, builder, moduleTranslation);
8398 })
8399 .Case([&](omp::DistributeOp) {
8400 return convertOmpDistribute(*op, builder, moduleTranslation);
8401 })
8402 .Case([&](omp::LoopNestOp) {
8403 return convertOmpLoopNest(*op, builder, moduleTranslation);
8404 })
8405 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8406 omp::AffinityEntryOp, omp::IteratorOp>([&](auto op) {
8407 // No-op, should be handled by relevant owning operations e.g.
8408 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
8409 // etc. and then discarded
8410 return success();
8411 })
8412 .Case([&](omp::NewCliOp op) {
8413 // Meta-operation: Doesn't do anything by itself, but used to
8414 // identify a loop.
8415 return success();
8416 })
8417 .Case([&](omp::CanonicalLoopOp op) {
8418 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
8419 })
8420 .Case([&](omp::UnrollHeuristicOp op) {
8421 // FIXME: Handling omp.unroll_heuristic as an executable requires
8422 // that the generator (e.g. omp.canonical_loop) has been seen first.
8423 // For construct that require all codegen to occur inside a callback
8424 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
8425 // contained region including their transformations must occur at
8426 // the omp.canonical_loop.
8427 return applyUnrollHeuristic(op, builder, moduleTranslation);
8428 })
8429 .Case([&](omp::TileOp op) {
8430 return applyTile(op, builder, moduleTranslation);
8431 })
8432 .Case([&](omp::FuseOp op) {
8433 return applyFuse(op, builder, moduleTranslation);
8434 })
8435 .Case([&](omp::TargetAllocMemOp) {
8436 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
8437 })
8438 .Case([&](omp::TargetFreeMemOp) {
8439 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
8440 })
8441 .Case([&](omp::AllocateDirOp) {
8442 return convertAllocateDirOp(*op, builder, moduleTranslation, *this);
8443 })
8444 .Case([&](omp::AllocateFreeOp) {
8445 return convertAllocateFreeOp(*op, builder, moduleTranslation,
8446 *this);
8447 })
8448 .Case([&](omp::AllocSharedMemOp op) {
8449 return convertAllocSharedMemOp(op, builder, moduleTranslation);
8450 })
8451 .Case([&](omp::FreeSharedMemOp op) {
8452 return convertFreeSharedMemOp(op, builder, moduleTranslation);
8453 })
8454 .Case([&](omp::GroupprivateOp) {
8455 return convertOmpGroupprivate(*op, builder, moduleTranslation);
8456 })
8457 .Default([&](Operation *inst) {
8458 return inst->emitError()
8459 << "not yet implemented: " << inst->getName();
8460 });
8461
8462 if (isOutermostLoopWrapper)
8463 moduleTranslation.stackPop();
8464
8465 return result;
8466}
8467
8469 registry.insert<omp::OpenMPDialect>();
8470 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
8471 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
8472 });
8473}
8474
8476 DialectRegistry registry;
8478 context.appendDialectRegistry(registry);
8479}
for(Operation *op :ops)
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values corresponding to the given list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Value getOperand(unsigned idx)
Definition Operation.h:376
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
unsigned getNumOperands()
Definition Operation.h:372
OperandRange operand_range
Definition Operation.h:397
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
result_range getResults()
Definition Operation.h:441
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
Definition Utils.cpp:66
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
Definition Utils.cpp:51
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars