MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
20#include "mlir/IR/Operation.h"
22#include "mlir/Support/LLVM.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
43
44#include <cstdint>
45#include <iterator>
46#include <numeric>
47#include <optional>
48#include <utility>
49
50using namespace mlir;
51
52namespace {
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
70 }
71 llvm_unreachable("unhandled schedule clause argument");
72}
73
74/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
75/// insertion points for allocas.
76class OpenMPAllocStackFrame
77 : public StateStackFrameBase<OpenMPAllocStackFrame> {
78public:
80
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
87};
88
89/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
90/// collapsed canonical loop information corresponding to an \c omp.loop_nest
91/// operation.
92class OpenMPLoopInfoStackFrame
93 : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
94public:
95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
96 llvm::CanonicalLoopInfo *loopInfo = nullptr;
97};
98
99/// Custom error class to signal translation errors that don't need reporting,
100/// since encountering them will have already triggered relevant error messages.
101///
102/// Its purpose is to serve as the glue between MLIR failures represented as
103/// \see LogicalResult instances and \see llvm::Error instances used to
104/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
105/// error of the first type is raised, a message is emitted directly (the \see
106/// LogicalResult itself does not hold any information). If we need to forward
107/// this error condition as an \see llvm::Error while avoiding triggering some
108/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
109/// class to just signal this situation has happened.
110///
111/// For example, this class should be used to trigger errors from within
112/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
113/// translation of their own regions. This unclutters the error log from
114/// redundant messages.
115class PreviouslyReportedError
116 : public llvm::ErrorInfo<PreviouslyReportedError> {
117public:
118 void log(raw_ostream &) const override {
119 // Do not log anything.
120 }
121
122 std::error_code convertToErrorCode() const override {
123 llvm_unreachable(
124 "PreviouslyReportedError doesn't support ECError conversion");
125 }
126
127 // Used by ErrorInfo::classID.
128 static char ID;
129};
130
131char PreviouslyReportedError::ID = 0;
132
133/*
134 * Custom class for processing linear clause for omp.wsloop
135 * and omp.simd. Linear clause translation requires setup,
136 * initialization, update, and finalization at varying
137 * basic blocks in the IR. This class helps maintain
138 * internal state to allow consistent translation in
139 * each of these stages.
140 */
141
142class LinearClauseProcessor {
143
144private:
145 SmallVector<llvm::Value *> linearPreconditionVars;
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
147 SmallVector<llvm::Value *> linearOrigVal;
148 SmallVector<llvm::Value *> linearSteps;
149 SmallVector<llvm::Type *> linearVarTypes;
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
153
154public:
155 // Register type for the linear variables
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
160 }
161
162 // Allocate space for linear variabes
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar, int idx) {
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
172 }
173
174 // Initialize linear step
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
178 }
179
180 // Emit IR for initialization of linear variables
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (size_t index = 0; index < linearOrigVal.size(); index++) {
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
189 }
190 }
191
192 // Emit IR for updating Linear variables
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
200
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable("OpenMP loop induction variable must be an integer "
203 "type");
204
205 if (linearVarType->isIntegerTy()) {
206 // Integer path: normalize all arithmetic to linearVarType
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
209
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 } else if (linearVarType->isFloatingPointTy()) {
216 // Float path: perform multiply in integer, then convert to float
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
219
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
225 } else {
226 llvm_unreachable(
227 "Linear variable must be of integer or floating-point type");
228 }
229 }
230 }
231
232 // Linear variable finalization is conditional on the last logical iteration.
233 // Create BB splits to manage the same.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(), "omp_loop.linear_finalization");
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
242 }
243
244 // Finalize the linear vars
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
249 // Emit condition to check whether last logical iteration is being executed
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
257 // Store the linear variable values to original variables.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
263 }
264
265 // Create conditional branch such that the linear variable
266 // values are stored to original variables only at the
267 // last logical iteration
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
271 // Emit barrier
272 builder.SetInsertPoint(linearExitBB->getTerminator());
273 return moduleTranslation.getOpenMPBuilder()->createBarrier(
274 builder.saveIP(), llvm::omp::OMPD_barrier);
275 }
276
277 // Emit stores for linear variables. Useful in case of SIMD
278 // construct.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
284 }
285 }
286
287 // Rewrite all uses of the original variable, in the basic blocks whose names
288 // start with `prefix`, with the linear variable in-place.
289 void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
290 llvm::BasicBlock *endBB, llvm::StringRef prefix,
291 size_t varIndex) {
292 llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
293 llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
294 llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;
295
296 assert(startBB && endBB && "Invalid startBB/endBB");
297
298 // Traverse basic blocks from startBB to endBB and save those
299 // whose names start with the specified prefix.
300 worklist.push_back(startBB);
301 visited.insert(startBB);
302
303 while (!worklist.empty()) {
304 llvm::BasicBlock *bb = worklist.pop_back_val();
305
306 if (bb->hasName() && bb->getName().starts_with(prefix))
307 matchingBBs.insert(bb);
308
309 if (bb == endBB)
310 continue;
311
312 for (llvm::BasicBlock *succ : llvm::successors(bb)) {
313 if (visited.insert(succ).second)
314 worklist.push_back(succ);
315 }
316 }
317
318 // Rewrite all uses in the matching BBs.
319 llvm::SmallVector<llvm::User *> users(linearOrigVal[varIndex]->users());
320 for (auto *user : users) {
321 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
322 if (matchingBBs.contains(userInst->getParent()))
323 user->replaceUsesOfWith(linearOrigVal[varIndex],
324 linearLoopBodyTemps[varIndex]);
325 }
326 }
327 }
328};
329
330} // namespace
331
332/// Looks up from the operation from and returns the PrivateClauseOp with
333/// name symbolName
334static omp::PrivateClauseOp findPrivatizer(Operation *from,
335 SymbolRefAttr symbolName) {
336 omp::PrivateClauseOp privatizer =
338 symbolName);
339 assert(privatizer && "privatizer not found in the symbol table");
340 return privatizer;
341}
342
343/// Check whether translation to LLVM IR for the given operation is currently
344/// supported. If not, descriptive diagnostics will be emitted to let users know
345/// this is a not-yet-implemented feature.
346///
347/// \returns success if no unimplemented features are needed to translate the
348/// given operation.
349static LogicalResult checkImplementationStatus(Operation &op) {
350 auto todo = [&op](StringRef clauseName) {
351 return op.emitError() << "not yet implemented: Unhandled clause "
352 << clauseName << " in " << op.getName()
353 << " operation";
354 };
355
356 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
357 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
358 result = todo("allocate");
359 };
360 auto checkBare = [&todo](auto op, LogicalResult &result) {
361 if (op.getBare())
362 result = todo("ompx_bare");
363 };
364 auto checkDepend = [&todo](auto op, LogicalResult &result) {
365 if (!op.getDependVars().empty() || op.getDependKinds())
366 result = todo("depend");
367 };
368 auto checkHint = [](auto op, LogicalResult &) {
369 if (op.getHint())
370 op.emitWarning("hint clause discarded");
371 };
372 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
373 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
374 op.getInReductionSyms())
375 result = todo("in_reduction");
376 };
377 auto checkNowait = [&todo](auto op, LogicalResult &result) {
378 if (op.getNowait())
379 result = todo("nowait");
380 };
381 auto checkOrder = [&todo](auto op, LogicalResult &result) {
382 if (op.getOrder() || op.getOrderMod())
383 result = todo("order");
384 };
385 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
386 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
387 result = todo("privatization");
388 };
389 auto checkReduction = [&todo](auto op, LogicalResult &result) {
390 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
391 if (!op.getReductionVars().empty() || op.getReductionByref() ||
392 op.getReductionSyms())
393 result = todo("reduction");
394 if (op.getReductionMod() &&
395 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
396 result = todo("reduction with modifier");
397 };
398 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
399 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
400 op.getTaskReductionSyms())
401 result = todo("task_reduction");
402 };
403 auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
404 if (op.hasNumTeamsMultiDim())
405 result = todo("num_teams with multi-dimensional values");
406 };
407 auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
408 if (op.hasNumThreadsMultiDim())
409 result = todo("num_threads with multi-dimensional values");
410 };
411
412 auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
413 if (op.hasThreadLimitMultiDim())
414 result = todo("thread_limit with multi-dimensional values");
415 };
416
417 auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
418 if (op.getDynGroupprivateSize())
419 result = todo("dyn_groupprivate");
420 };
421
422 LogicalResult result = success();
424 .Case([&](omp::DistributeOp op) {
425 checkAllocate(op, result);
426 checkOrder(op, result);
427 })
428 .Case([&](omp::SectionsOp op) {
429 checkAllocate(op, result);
430 checkPrivate(op, result);
431 checkReduction(op, result);
432 })
433 .Case([&](omp::ScopeOp op) {
434 checkAllocate(op, result);
435 checkReduction(op, result);
436 })
437 .Case([&](omp::SingleOp op) {
438 checkAllocate(op, result);
439 checkPrivate(op, result);
440 })
441 .Case([&](omp::TeamsOp op) {
442 checkAllocate(op, result);
443 checkPrivate(op, result);
444 checkNumTeams(op, result);
445 checkThreadLimit(op, result);
446 checkDynGroupprivate(op, result);
447 })
448 .Case([&](omp::TaskOp op) {
449 checkAllocate(op, result);
450 checkInReduction(op, result);
451 })
452 .Case([&](omp::TaskgroupOp op) {
453 checkAllocate(op, result);
454 checkTaskReduction(op, result);
455 })
456 .Case([&](omp::TaskwaitOp op) {
457 checkDepend(op, result);
458 checkNowait(op, result);
459 })
460 .Case([&](omp::TaskloopContextOp op) {
461 checkAllocate(op, result);
462 checkInReduction(op, result);
463 checkReduction(op, result);
464 })
465 .Case([&](omp::WsloopOp op) {
466 checkAllocate(op, result);
467 checkOrder(op, result);
468 checkReduction(op, result);
469 })
470 .Case([&](omp::ParallelOp op) {
471 checkAllocate(op, result);
472 checkReduction(op, result);
473 checkNumThreads(op, result);
474 })
475 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
476 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
477 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
478 .Case([&](omp::AtomicCompareOp op) {
479 checkHint(op, result);
480 Region &region = op.getRegion();
481 if (region.empty())
482 return;
483 mlir::Type argType = region.front().getArgument(0).getType();
484 auto structTy = dyn_cast<LLVM::LLVMStructType>(argType);
485 if (!structTy)
486 return;
487 DataLayout dl = DataLayout(op->getParentOfType<ModuleOp>());
488 unsigned totalBits = dl.getTypeSizeInBits(structTy);
489 if (totalBits > 128)
490 result = todo("compare for complex types wider than 128 bits");
491 })
492 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
493 [&](auto op) { checkDepend(op, result); })
494 .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
495 .Case([&](omp::TargetOp op) {
496 checkAllocate(op, result);
497 checkBare(op, result);
498 checkInReduction(op, result);
499 checkThreadLimit(op, result);
500 })
501 .Default([](Operation &) {
502 // Assume all clauses for an operation can be translated unless they are
503 // checked above.
504 });
505 return result;
506}
507
508static LogicalResult handleError(llvm::Error error, Operation &op) {
509 LogicalResult result = success();
510 if (error) {
511 llvm::handleAllErrors(
512 std::move(error),
513 [&](const PreviouslyReportedError &) { result = failure(); },
514 [&](const llvm::ErrorInfoBase &err) {
515 result = op.emitError(err.message());
516 });
517 }
518 return result;
519}
520
521template <typename T>
522static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
523 if (!result)
524 return handleError(result.takeError(), op);
525
526 return success();
527}
528
529/// Find the insertion point for allocas given the current insertion point for
530/// normal operations in the builder.
531static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(
532 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
533 llvm::SmallVectorImpl<llvm::BasicBlock *> *deallocBlocks = nullptr) {
534 // If there is an allocation insertion point on stack, i.e. we are in a nested
535 // operation and a specific point was provided by some surrounding operation,
536 // use it.
537 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
538 llvm::ArrayRef<llvm::BasicBlock *> deallocInsertPoints;
539 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocStackFrame>(
540 [&](OpenMPAllocStackFrame &frame) {
541 allocInsertPoint = frame.allocInsertPoint;
542 deallocInsertPoints = frame.deallocBlocks;
543 return WalkResult::interrupt();
544 });
545 // In cases with multiple levels of outlining, the tree walk might find an
546 // insertion point that is inside the original function while the builder
547 // insertion point is inside the outlined function. We need to make sure that
548 // we do not use it in those cases.
549 if (walkResult.wasInterrupted() &&
550 allocInsertPoint.getBlock()->getParent() ==
551 builder.GetInsertBlock()->getParent()) {
552 if (deallocBlocks)
553 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
554 deallocInsertPoints.end());
555 return allocInsertPoint;
556 }
557
558 // Otherwise, insert to the entry block of the surrounding function.
559 // If the current IRBuilder InsertPoint is the function's entry, it cannot
560 // also be used for alloca insertion which would result in insertion order
561 // confusion. Create a new BasicBlock for the Builder and use the entry block
562 // for the allocs.
563 // TODO: Create a dedicated alloca BasicBlock at function creation such that
564 // we do not need to move the current InsertPoint here.
565 if (builder.GetInsertBlock() ==
566 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
567 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
568 "Assuming end of basic block");
569 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
570 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
571 builder.GetInsertBlock()->getNextNode());
572 builder.CreateBr(entryBB);
573 builder.SetInsertPoint(entryBB);
574 }
575
576 // Collect exit blocks, which is where explicit deallocations should happen in
577 // this case.
578 if (deallocBlocks) {
579 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
580 // TODO: This currently results in no blocks being added to the list when
581 // all exit blocks of the enclosing function have not been lowered before
582 // this is reached.
583 llvm::Instruction *terminator = block.getTerminatorOrNull();
584 if (isa_and_present<llvm::ReturnInst>(terminator))
585 deallocBlocks->emplace_back(&block);
586 }
587 }
588
589 llvm::BasicBlock &funcEntryBlock =
590 builder.GetInsertBlock()->getParent()->getEntryBlock();
591 return llvm::OpenMPIRBuilder::InsertPointTy(
592 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
593}
594
595/// Find the loop information structure for the loop nest being translated. It
596/// will return a `null` value unless called from the translation function for
597/// a loop wrapper operation after successfully translating its body.
598static llvm::CanonicalLoopInfo *
600 llvm::CanonicalLoopInfo *loopInfo = nullptr;
601 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
602 [&](OpenMPLoopInfoStackFrame &frame) {
603 loopInfo = frame.loopInfo;
604 return WalkResult::interrupt();
605 });
606 return loopInfo;
607}
608
609/// Converts the given region that appears within an OpenMP dialect operation to
610/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
611/// region, and a branch from any block with an successor-less OpenMP terminator
612/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
613/// of the continuation block if provided.
615 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
616 LLVM::ModuleTranslation &moduleTranslation,
617 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
618 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
619
620 llvm::BasicBlock *continuationBlock =
621 splitBB(builder, true, "omp.region.cont");
622 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
623
624 llvm::LLVMContext &llvmContext = builder.getContext();
625 for (Block &bb : region) {
626 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
627 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
628 builder.GetInsertBlock()->getNextNode());
629 moduleTranslation.mapBlock(&bb, llvmBB);
630 }
631
632 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
633
634 // Terminators (namely YieldOp) may be forwarding values to the region that
635 // need to be available in the continuation block. Collect the types of these
636 // operands in preparation of creating PHI nodes. This is skipped for loop
637 // wrapper operations, for which we know in advance they have no terminators.
638 SmallVector<llvm::Type *> continuationBlockPHITypes;
639 unsigned numYields = 0;
640
641 if (!isLoopWrapper) {
642 bool operandsProcessed = false;
643 for (Block &bb : region.getBlocks()) {
644 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
645 if (!operandsProcessed) {
646 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
647 continuationBlockPHITypes.push_back(
648 moduleTranslation.convertType(yield->getOperand(i).getType()));
649 }
650 operandsProcessed = true;
651 } else {
652 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
653 "mismatching number of values yielded from the region");
654 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
655 llvm::Type *operandType =
656 moduleTranslation.convertType(yield->getOperand(i).getType());
657 (void)operandType;
658 assert(continuationBlockPHITypes[i] == operandType &&
659 "values of mismatching types yielded from the region");
660 }
661 }
662 numYields++;
663 }
664 }
665 }
666
667 // Insert PHI nodes in the continuation block for any values forwarded by the
668 // terminators in this region.
669 if (!continuationBlockPHITypes.empty())
670 assert(
671 continuationBlockPHIs &&
672 "expected continuation block PHIs if converted regions yield values");
673 if (continuationBlockPHIs) {
674 llvm::IRBuilderBase::InsertPointGuard guard(builder);
675 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
676 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
677 for (llvm::Type *ty : continuationBlockPHITypes)
678 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
679 }
680
681 // Convert blocks one by one in topological order to ensure
682 // defs are converted before uses.
684 for (Block *bb : blocks) {
685 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
686 // Retarget the branch of the entry block to the entry block of the
687 // converted region (regions are single-entry).
688 if (bb->isEntryBlock()) {
689 assert(sourceTerminator->getNumSuccessors() == 1 &&
690 "provided entry block has multiple successors");
691 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
692 "ContinuationBlock is not the successor of the entry block");
693 sourceTerminator->setSuccessor(0, llvmBB);
694 }
695
696 llvm::IRBuilderBase::InsertPointGuard guard(builder);
697 if (failed(
698 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
699 return llvm::make_error<PreviouslyReportedError>();
700
701 // Create a direct branch here for loop wrappers to prevent their lack of a
702 // terminator from causing a crash below.
703 if (isLoopWrapper) {
704 builder.CreateBr(continuationBlock);
705 continue;
706 }
707
708 // Special handling for `omp.yield` and `omp.terminator` (we may have more
709 // than one): they return the control to the parent OpenMP dialect operation
710 // so replace them with the branch to the continuation block. We handle this
711 // here to avoid relying inter-function communication through the
712 // ModuleTranslation class to set up the correct insertion point. This is
713 // also consistent with MLIR's idiom of handling special region terminators
714 // in the same code that handles the region-owning operation.
715 Operation *terminator = bb->getTerminator();
716 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
717 builder.CreateBr(continuationBlock);
718
719 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
720 (*continuationBlockPHIs)[i]->addIncoming(
721 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
722 }
723 }
724 // After all blocks have been traversed and values mapped, connect the PHI
725 // nodes to the results of preceding blocks.
726 LLVM::detail::connectPHINodes(region, moduleTranslation);
727
728 // Remove the blocks and values defined in this region from the mapping since
729 // they are not visible outside of this region. This allows the same region to
730 // be converted several times, that is cloned, without clashes, and slightly
731 // speeds up the lookups.
732 moduleTranslation.forgetMapping(region);
733
734 return continuationBlock;
735}
736
737/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
738static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
739 switch (kind) {
740 case omp::ClauseProcBindKind::Close:
741 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
742 case omp::ClauseProcBindKind::Master:
743 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
744 case omp::ClauseProcBindKind::Primary:
745 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
746 case omp::ClauseProcBindKind::Spread:
747 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
748 }
749 llvm_unreachable("Unknown ClauseProcBindKind kind");
750}
751
752/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
///
/// Returns failure if checkImplementationStatus rejects the operation or if
/// OpenMPIRBuilder::createMasked reports an error; otherwise the builder is
/// moved to the insertion point returned by createMasked and success is
/// returned.
753static LogicalResult
754convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
755 LLVM::ModuleTranslation &moduleTranslation) {
756 auto maskedOp = cast<omp::MaskedOp>(opInst);
757 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
758
759 if (failed(checkImplementationStatus(opInst)))
760 return failure();
761
 // Body callback: positions the builder at codeGenIP and translates the
 // operation's region, forwarding any error from convertOmpOpRegions.
762 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
 // NOTE(review): the rest of this parameter list (and the opening brace of
 // the lambda) is missing from this listing - restore it from the source.
764 // MaskedOp has only one region associated with it.
765 auto &region = maskedOp.getRegion();
766 builder.restoreIP(codeGenIP);
767 return convertOmpOpRegions(region, "omp.masked.region", builder,
768 moduleTranslation)
769 .takeError();
770 };
771
772 // TODO: Perform finalization actions for variables. This has to be
773 // called for variables which have destructors/finalizers.
 // Until then the finalization callback unconditionally reports success.
774 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
775
 // Use the translated filter operand when the operation has one; otherwise
 // fall back to the i32 constant 0.
776 llvm::Value *filterVal = nullptr;
777 if (auto filterVar = maskedOp.getFilteredThreadId()) {
778 filterVal = moduleTranslation.lookupValue(filterVar);
779 } else {
780 llvm::LLVMContext &llvmContext = builder.getContext();
781 filterVal =
782 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
783 }
784 assert(filterVal != nullptr);
785 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
786 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
787 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
788 finiCB, filterVal);
789
 // Converts an error result into a diagnostic (if not already reported).
790 if (failed(handleError(afterIP, opInst)))
791 return failure();
792
 // Continue emitting code after the generated construct.
793 builder.restoreIP(*afterIP);
794 return success();
795}
796
797/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
///
/// Returns failure if checkImplementationStatus rejects the operation or if
/// OpenMPIRBuilder::createMaster reports an error; otherwise the builder is
/// moved to the insertion point returned by createMaster and success is
/// returned.
798static LogicalResult
799convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
800 LLVM::ModuleTranslation &moduleTranslation) {
801 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
802 auto masterOp = cast<omp::MasterOp>(opInst);
803
804 if (failed(checkImplementationStatus(opInst)))
805 return failure();
806
 // Body callback: positions the builder at codeGenIP and translates the
 // operation's region, forwarding any error from convertOmpOpRegions.
807 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
 // NOTE(review): the rest of this parameter list (and the opening brace of
 // the lambda) is missing from this listing - restore it from the source.
809 // MasterOp has only one region associated with it.
810 auto &region = masterOp.getRegion();
811 builder.restoreIP(codeGenIP);
812 return convertOmpOpRegions(region, "omp.master.region", builder,
813 moduleTranslation)
814 .takeError();
815 };
816
817 // TODO: Perform finalization actions for variables. This has to be
818 // called for variables which have destructors/finalizers.
 // Until then the finalization callback unconditionally reports success.
819 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
820
821 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
822 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
823 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
824 finiCB);
825
 // Converts an error result into a diagnostic (if not already reported).
826 if (failed(handleError(afterIP, opInst)))
827 return failure();
828
 // Continue emitting code after the generated construct.
829 builder.restoreIP(*afterIP);
830 return success();
831}
832
833/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
/// The critical section's name (empty when the op is unnamed) and, for named
/// sections, the hint read from the referenced declaration are passed to
/// `createCritical`. Returns failure when `checkImplementationStatus` or
/// `createCritical` reports an error.
834static LogicalResult
835convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
836 LLVM::ModuleTranslation &moduleTranslation) {
837 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
838 auto criticalOp = cast<omp::CriticalOp>(opInst);
839
840 if (failed(checkImplementationStatus(opInst)))
841 return failure();
842
// Body callback: moves `builder` to `codeGenIP`, translates the op's region
// and forwards any translation error to the caller.
// NOTE(review): the remainder of this lambda's parameter list (listing line
// 844) appears to be missing from this copy -- confirm against the original.
843 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
845 // CriticalOp has only one region associated with it.
846 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
847 builder.restoreIP(codeGenIP);
848 return convertOmpOpRegions(region, "omp.critical.region", builder,
849 moduleTranslation)
850 .takeError();
851 };
852
853 // TODO: Perform finalization actions for variables. This has to be
854 // called for variables which have destructors/finalizers.
// Until then the finalization callback is a no-op that reports success.
855 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
856
857 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
858 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
// Stays null for unnamed critical sections.
859 llvm::Constant *hint = nullptr;
860
861 // If it has a name, it probably has a hint too.
862 if (criticalOp.getNameAttr()) {
863 // The verifiers in OpenMP Dialect guarantee that all the pointers are
864 // non-null
865 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
// NOTE(review): the start of the initializer that resolves `symbolRef` to the
// declaration op (listing line 867) appears to be missing from this copy.
866 auto criticalDeclareOp =
868 symbolRef);
// The hint is materialized as an i32 constant.
869 hint =
870 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
871 static_cast<int>(criticalDeclareOp.getHint()));
872 }
873 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
874 moduleTranslation.getOpenMPBuilder()->createCritical(
875 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
876
// `afterIP` is only dereferenced once handleError has confirmed success.
877 if (failed(handleError(afterIP, opInst)))
878 return failure();
879
880 builder.restoreIP(*afterIP);
881 return success();
882}
883
884/// A util to collect info needed to convert delayed privatizers from MLIR to
885/// LLVM.
// NOTE(review): the type's header line (listing line 886), the constructor's
// declarator (888) and the member declarations (899-902) appear to be missing
// from this copy -- confirm against the original file.
//
// The constructor takes the private block arguments from the op's
// BlockArgOpenMPOpInterface, reserves room in `mlirVars` and `llvmVars` for
// one entry per block argument, collects the privatizer declarations and
// copies the op's private variables into `mlirVars`. `llvmVars` is only
// reserved here, not filled.
887 template <typename OP>
889 : blockArgs(
890 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
891 mlirVars.reserve(blockArgs.size());
892 llvmVars.reserve(blockArgs.size());
893 collectPrivatizationDecls<OP>(op);
894
895 for (mlir::Value privateVar : op.getPrivateVars())
896 mlirVars.push_back(privateVar);
897 }
898
903
904private:
905 /// Appends to `privatizers` the privatization declarations used for the
906 /// given op. Does nothing when the op carries no private symbols.
907 template <class OP>
908 void collectPrivatizationDecls(OP op) {
909 std::optional<ArrayAttr> attr = op.getPrivateSyms();
910 if (!attr)
911 return;
912
// One privatizer is looked up per symbol reference, in attribute order.
913 privatizers.reserve(privatizers.size() + attr->size());
914 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
915 privatizers.push_back(findPrivatizer(op, symbolRef));
916 }
917 }
918};
919
920/// Populates `reductions` with reduction declarations used in the given op.
/// Existing entries are kept; new declarations are appended in the order of
/// the op's reduction symbols. Does nothing when the op has no such symbols.
// NOTE(review): the function name and parameter list (listing lines 923-924)
// and the start of the symbol lookup expression (932) appear to be missing
// from this copy -- confirm against the original file.
921template <typename T>
922static void
925 std::optional<ArrayAttr> attr = op.getReductionSyms();
926 if (!attr)
927 return;
928
929 reductions.reserve(reductions.size() + op.getNumReductionVars());
930 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
931 reductions.push_back(
933 op, symbolRef));
934 }
935}
936
937/// Translates the blocks contained in the given region and appends them at
938/// the current insertion point of `builder`. The operations of the entry block
939/// are appended to the current insertion block. If set, `continuationBlockArgs`
940/// is populated with translated values that correspond to the values
941/// omp.yield'ed from the region.
/// An empty region is a no-op that reports success. Returns failure when the
/// block or region translation fails.
942static LogicalResult inlineConvertOmpRegions(
943 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
944 LLVM::ModuleTranslation &moduleTranslation,
945 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
946 if (region.empty())
947 return success();
948
949 // Special case for single-block regions that don't create additional blocks:
950 // insert operations without creating additional blocks.
951 if (region.hasOneBlock()) {
// If the insertion block already ends in a terminator, detach it for the
// duration of the translation; it is re-attached at the end of the block below.
952 llvm::Instruction *potentialTerminator =
953 builder.GetInsertBlock()->empty() ? nullptr
954 : &builder.GetInsertBlock()->back();
955
956 if (potentialTerminator && potentialTerminator->isTerminator())
957 potentialTerminator->removeFromParent();
958 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
959
960 if (failed(moduleTranslation.convertBlock(
961 region.front(), /*ignoreArguments=*/true, builder)))
962 return failure();
963
964 // The continuation arguments are simply the translated terminator operands.
965 if (continuationBlockArgs)
966 llvm::append_range(
967 *continuationBlockArgs,
968 moduleTranslation.lookupValues(region.front().back().getOperands()));
969
970 // Drop the mapping that is no longer necessary so that the same region can
971 // be processed multiple times.
972 moduleTranslation.forgetMapping(region);
973
974 if (potentialTerminator && potentialTerminator->isTerminator()) {
975 llvm::BasicBlock *block = builder.GetInsertBlock();
976 if (block->empty()) {
977 // this can happen for really simple reduction init regions e.g.
978 // %0 = llvm.mlir.constant(0 : i32) : i32
979 // omp.yield(%0 : i32)
980 // because the llvm.mlir.constant (MLIR op) isn't converted into any
981 // llvm op
982 potentialTerminator->insertInto(block, block->begin());
983 } else {
984 potentialTerminator->insertAfter(&block->back());
985 }
986 }
987
988 return success();
989 }
990
// Multi-block regions go through convertOmpOpRegions, which yields a
// continuation block; the values collected in `phis` become the continuation
// arguments.
// NOTE(review): the declaration of `phis` (listing line 991) appears to be
// missing from this copy -- confirm against the original file.
992 llvm::Expected<llvm::BasicBlock *> continuationBlock =
993 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
994
995 if (failed(handleError(continuationBlock, *region.getParentOp())))
996 return failure();
997
998 if (continuationBlockArgs)
999 llvm::append_range(*continuationBlockArgs, phis);
// Continue code generation at the first insertion point of the continuation
// block.
1000 builder.SetInsertPoint(*continuationBlock,
1001 (*continuationBlock)->getFirstInsertionPt());
1002 return success();
1003}
1004
1005namespace {
1006/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
1007/// store lambdas with capture.
1008using OwningReductionGen =
1009 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1010 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
1011 llvm::Value *&)>;
1012using OwningAtomicReductionGen =
1013 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1014 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
1015 llvm::Value *)>;
1016using OwningDataPtrPtrReductionGen =
1017 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1018 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
1019} // namespace
1020
1021/// Create an OpenMPIRBuilder-compatible reduction generator for the given
1022/// reduction declaration. The generator uses `builder` but ignores its
1023/// insertion point.
/// When invoked, the generator maps the declaration's LHS/RHS reduction
/// arguments to the given values, inlines the reduction region at the given
/// insertion point and stores the single yielded value in `result`.
1024static OwningReductionGen
1025makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1026 LLVM::ModuleTranslation &moduleTranslation) {
1027 // The lambda is mutable because we need access to non-const methods of decl
1028 // (which aren't actually mutating it), and we must capture decl by-value to
1029 // avoid the dangling reference after the parent function returns.
// `builder` and `moduleTranslation` are captured by reference.
1030 OwningReductionGen gen =
1031 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1032 llvm::Value *lhs, llvm::Value *rhs,
1033 llvm::Value *&result) mutable
1034 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1035 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
1036 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1037 builder.restoreIP(insertPoint);
// NOTE(review): the declaration of `phis` (listing line 1038) appears to be
// missing from this copy -- confirm against the original file.
1039 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
1040 "omp.reduction.nonatomic.body", builder,
1041 moduleTranslation, &phis)))
1042 return llvm::createStringError(
1043 "failed to inline `combiner` region of `omp.declare_reduction`");
// The region is expected to yield exactly one value: the combined result.
1044 result = llvm::getSingleElement(phis);
1045 return builder.saveIP();
1046 };
1047 return gen;
1048}
1049
1050/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
1051/// given reduction declaration. The generator uses `builder` but ignores its
1052/// insertion point. Returns null if there is no atomic region available in the
1053/// reduction declaration.
/// When invoked, the generator maps the declaration's atomic LHS/RHS arguments
/// to the given values and inlines the atomic reduction region at the given
/// insertion point; the `llvm::Type *` argument is ignored.
1054static OwningAtomicReductionGen
1055makeAtomicReductionGen(omp::DeclareReductionOp decl,
1056 llvm::IRBuilderBase &builder,
1057 LLVM::ModuleTranslation &moduleTranslation) {
// A default-constructed (empty) generator signals "no atomic region".
1058 if (decl.getAtomicReductionRegion().empty())
1059 return OwningAtomicReductionGen();
1060
1061 // The lambda is mutable because we need access to non-const methods of decl
1062 // (which aren't actually mutating it), and we must capture decl by-value to
1063 // avoid the dangling reference after the parent function returns.
1064 OwningAtomicReductionGen atomicGen =
1065 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1066 llvm::Value *lhs, llvm::Value *rhs) mutable
1067 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1068 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1069 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1070 builder.restoreIP(insertPoint);
// NOTE(review): the declaration of `phis` (listing line 1071) appears to be
// missing from this copy -- confirm against the original file.
1072 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1073 "omp.reduction.atomic.body", builder,
1074 moduleTranslation, &phis)))
1075 return llvm::createStringError(
1076 "failed to inline `atomic` region of `omp.declare_reduction`");
// The atomic region is expected to yield no values.
1077 assert(phis.empty());
1078 return builder.saveIP();
1079 };
1080 return atomicGen;
1081}
1082
1083/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1084/// the given reduction declaration. The generator uses `builder` but ignores
1085/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1086/// available in the reduction declaration.
/// A null generator is also returned when `isByRef` is false. When invoked,
/// the generator maps the region's argument to `byRefVal`, inlines the region
/// at the given insertion point and stores the single yielded value in
/// `result`.
1087static OwningDataPtrPtrReductionGen
1088makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1089 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1090 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1091 return OwningDataPtrPtrReductionGen();
1092
// `decl` is captured by value so the generator does not refer to this
// function's parameter after it returns; the lambda is mutable so that
// non-const methods of `decl` can be called.
1093 OwningDataPtrPtrReductionGen refDataPtrGen =
1094 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1095 llvm::Value *byRefVal, llvm::Value *&result) mutable
1096 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1097 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1098 builder.restoreIP(insertPoint);
// NOTE(review): the declaration of `phis` (listing line 1099) appears to be
// missing from this copy -- confirm against the original file.
1100 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1101 "omp.data_ptr_ptr.body", builder,
1102 moduleTranslation, &phis)))
1103 return llvm::createStringError(
1104 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
// The region is expected to yield exactly one value.
1105 result = llvm::getSingleElement(phis);
1106 return builder.saveIP();
1107 };
1108
1109 return refDataPtrGen;
1110}
1111
1112/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1113static LogicalResult
1114convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1115 LLVM::ModuleTranslation &moduleTranslation) {
1116 auto orderedOp = cast<omp::OrderedOp>(opInst);
1117
1118 if (failed(checkImplementationStatus(opInst)))
1119 return failure();
1120
1121 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1122 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1123 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1124 SmallVector<llvm::Value *> vecValues =
1125 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1126
1127 size_t indexVecValues = 0;
1128 while (indexVecValues < vecValues.size()) {
1129 SmallVector<llvm::Value *> storeValues;
1130 storeValues.reserve(numLoops);
1131 for (unsigned i = 0; i < numLoops; i++) {
1132 storeValues.push_back(vecValues[indexVecValues]);
1133 indexVecValues++;
1134 }
1135 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1136 findAllocInsertPoints(builder, moduleTranslation);
1137 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1138 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1139 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1140 }
1141 return success();
1142}
1143
1144/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1145/// OpenMPIRBuilder.
1146static LogicalResult
1147convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1148 LLVM::ModuleTranslation &moduleTranslation) {
1149 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1150 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1151
1152 if (failed(checkImplementationStatus(opInst)))
1153 return failure();
1154
1155 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1156 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
1157 // OrderedOp has only one region associated with it.
1158 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1159 builder.restoreIP(codeGenIP);
1160 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1161 moduleTranslation)
1162 .takeError();
1163 };
1164
1165 // TODO: Perform finalization actions for variables. This has to be
1166 // called for variables which have destructors/finalizers.
1167 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1168
1169 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1170 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1171 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1172 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1173
1174 if (failed(handleError(afterIP, opInst)))
1175 return failure();
1176
1177 builder.restoreIP(*afterIP);
1178 return success();
1179}
1180
1181namespace {
1182/// Contains the arguments for an LLVM store operation
1183struct DeferredStore {
1184 DeferredStore(llvm::Value *value, llvm::Value *address)
1185 : value(value), address(address) {}
1186
1187 llvm::Value *value;
1188 llvm::Value *address;
1189};
1190} // namespace
1191
1192/// Allocate space for privatized reduction variables.
1193/// `deferredStores` contains information to create store operations which needs
1194/// to be inserted after all allocas
/// For every reduction variable that gets an allocation here, the function
/// records it in `privateReductionVariables[i]`, maps `reductionArgs[i]` in
/// `moduleTranslation` and adds an entry to `reductionVariableMap`. By-ref
/// reductions without an `alloc` region are skipped. The builder's insertion
/// point is restored on exit by the guard below.
// NOTE(review): the function name / first parameters (listing line 1197), one
// parameter (1201) and the declaration of `phis` (1221) appear to be missing
// from this copy -- confirm against the original file.
1195template <typename T>
1196static LogicalResult
1198 llvm::IRBuilderBase &builder,
1199 LLVM::ModuleTranslation &moduleTranslation,
1200 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1202 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1203 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1204 SmallVectorImpl<DeferredStore> &deferredStores,
1205 llvm::ArrayRef<bool> isByRefs) {
1206 llvm::IRBuilderBase::InsertPointGuard guard(builder);
// Everything is emitted before the terminator of the alloca block.
1207 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1208
1209 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1210 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1211
1212 // delay creating stores until after all allocas
1213 deferredStores.reserve(op.getNumReductionVars());
1214
1215 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1216 Region &allocRegion = reductionDecls[i].getAllocRegion();
1217 if (isByRefs[i]) {
// By-ref reduction without an alloc region: nothing to allocate here.
1218 if (allocRegion.empty())
1219 continue;
1220
1222 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1223 builder, moduleTranslation, &phis)))
1224 return op.emitError(
1225 "failed to inline `alloc` region of `omp.declare_reduction`");
1226
1227 assert(phis.size() == 1 && "expected one allocation to be yielded");
// Inlining may have moved the insertion point; go back to the end of the
// alloca block.
1228 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1229
1230 // Allocate reduction variable (which is a pointer to the real reduction
1231 // variable allocated in the inlined region)
1232 llvm::Type *ptrTy = builder.getPtrTy();
1233 llvm::Type *varTy =
1234 moduleTranslation.convertType(reductionDecls[i].getType());
1235 llvm::Value *var;
1236 if (useDeviceSharedMem) {
1237 var = ompBuilder->createOMPAllocShared(builder, varTy);
1238 } else {
1239 var = builder.CreateAlloca(varTy);
1240 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1241 }
1242
1243 llvm::Value *castPhi =
1244 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1245
// The store of the yielded allocation into `var` is only recorded here; the
// caller is responsible for emitting it.
1246 deferredStores.emplace_back(castPhi, var);
1247
1248 privateReductionVariables[i] = var;
1249 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1250 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1251 } else {
1252 assert(allocRegion.empty() &&
1253 "allocaction is implicit for by-val reduction");
1254
// By-value reduction: allocate the private copy directly.
1255 llvm::Type *ptrTy = builder.getPtrTy();
1256 llvm::Type *varTy =
1257 moduleTranslation.convertType(reductionDecls[i].getType());
1258 llvm::Value *var;
1259 if (useDeviceSharedMem) {
1260 var = ompBuilder->createOMPAllocShared(builder, varTy);
1261 } else {
1262 var = builder.CreateAlloca(varTy);
1263 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1264 }
1265
1266 moduleTranslation.mapValue(reductionArgs[i], var);
1267 privateReductionVariables[i] = var;
1268 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1269 }
1270 }
1271
1272 return success();
1273}
1274
1275/// Map input arguments to reduction initialization region
/// for the `i`-th reduction: the mold argument is mapped to the original
/// variable (loaded first when the region expects a non-pointer but the
/// variable is a pointer), and, when the entry block has more than one
/// argument, the alloc argument is mapped to the value recorded for the
/// variable in `reductionVariableMap`.
// NOTE(review): the function name / first parameters (listing line 1278) and
// one parameter (1280) appear to be missing from this copy -- confirm against
// the original file.
1276template <typename T>
1277static void
1279 llvm::IRBuilderBase &builder,
1281 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1282 unsigned i) {
1283 // map input argument to the initialization region
1284 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1285 Region &initializerRegion = reduction.getInitializerRegion();
1286 Block &entry = initializerRegion.front();
1287
1288 mlir::Value mlirSource = loop.getReductionVars()[i];
1289 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1290 llvm::Value *origVal = llvmSource;
1291 // If a non-pointer value is expected, load the value from the source pointer.
1292 if (!isa<LLVM::LLVMPointerType>(
1293 reduction.getInitializerMoldArg().getType()) &&
1294 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1295 origVal =
1296 builder.CreateLoad(moduleTranslation.convertType(
1297 reduction.getInitializerMoldArg().getType()),
1298 llvmSource, "omp_orig");
1299 }
1300 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1301
// A second entry-block argument, when present, receives the allocation.
1302 if (entry.getNumArguments() > 1) {
1303 llvm::Value *allocation =
1304 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1305 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1306 }
1307}
1308
1309static void
1310setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1311 llvm::BasicBlock *block = nullptr) {
1312 if (block == nullptr)
1313 block = builder.GetInsertBlock();
1314
1315 if (!block->hasTerminator())
1316 builder.SetInsertPoint(block);
1317 else
1318 builder.SetInsertPoint(block->getTerminator());
1319}
1320
1321/// Inline reductions' `init` regions. This function assumes that the
1322/// `builder`'s insertion point is where the user wants the `init` regions to be
1323/// inlined; i.e. it does not try to find a proper insertion location for the
1324/// `init` regions. It also leaves the `builder`'s insertion point in a state
1325/// where the user can continue the code-gen directly afterwards.
/// Does nothing when the op has no reduction variables. Returns failure when
/// an initializer region cannot be inlined.
// NOTE(review): one parameter (listing line 1332), the declaration of `phis`
// (1377) and one statement after the assert (1396) appear to be missing from
// this copy -- confirm against the original file.
1326template <typename OP>
1327static LogicalResult
1328initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1329 llvm::IRBuilderBase &builder,
1330 LLVM::ModuleTranslation &moduleTranslation,
1331 llvm::BasicBlock *latestAllocaBlock,
1333 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1334 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1335 llvm::ArrayRef<bool> isByRef,
1336 SmallVectorImpl<DeferredStore> &deferredStores) {
1337 if (op.getNumReductionVars() == 0)
1338 return success();
1339
1340 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1341 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1342
// Split off a dedicated block for the initialization code, then temporarily
// move to the end of the alloca block for the allocations below.
1343 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1344 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1345 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1346 builder.restoreIP(allocaIP);
1347 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1348
1349 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1350 if (isByRef[i]) {
1351 if (!reductionDecls[i].getAllocRegion().empty())
1352 continue;
1353
1354 // TODO: remove after all users of by-ref are updated to use the alloc
1355 // region: Allocate reduction variable (which is a pointer to the real
1356 // reduction variable allocated in the inlined region)
1357 llvm::Type *varTy =
1358 moduleTranslation.convertType(reductionDecls[i].getType());
1359 if (useDeviceSharedMem)
1360 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1361 else
1362 byRefVars[i] = builder.CreateAlloca(varTy);
1363 }
1364 }
1365
1366 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1367
1368 // store result of the alloc region to the allocated pointer to the real
1369 // reduction variable
1370 for (auto [data, addr] : deferredStores)
1371 builder.CreateStore(data, addr);
1372
1373 // Before the loop, store the initial values of reductions into reduction
1374 // variables. Although this could be done after allocas, we don't want to mess
1375 // up with the alloca insertion point.
1376 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1378
1379 // map block argument to initializer region
1380 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1381 reductionVariableMap, i);
1382
1383 // TODO In some cases (specially on the GPU), the init regions may
1384 // contains stack allocations. If the region is inlined in a loop, this is
1385 // problematic. Instead of just inlining the region, handle allocations by
1386 // hoisting fixed length allocations to the function entry and using
1387 // stacksave and restore for variable length ones.
1388 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1389 "omp.reduction.neutral", builder,
1390 moduleTranslation, &phis)))
1391 return failure();
1392
1393 assert(phis.size() == 1 && "expected one value to be yielded from the "
1394 "reduction neutral element declaration region");
1395
1397
1398 if (isByRef[i]) {
1399 if (!reductionDecls[i].getAllocRegion().empty())
1400 // done in allocReductionVars
1401 continue;
1402
1403 // TODO: this path can be removed once all users of by-ref are updated to
1404 // use an alloc region
1405
1406 // Store the result of the inlined region to the allocated reduction var
1407 // ptr
1408 builder.CreateStore(phis[0], byRefVars[i]);
1409
1410 privateReductionVariables[i] = byRefVars[i];
1411 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1412 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1413 } else {
1414 // for by-ref case the store is inside of the reduction region
1415 builder.CreateStore(phis[0], privateReductionVariables[i]);
1416 // the rest was handled in allocReductionVars
1417 }
1418
1419 // forget the mapping for the initializer region because we might need a
1420 // different mapping if this reduction declaration is re-used for a
1421 // different variable
1422 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1423 }
1424
1425 return success();
1426}
1427
1428/// Collect reduction info
/// First creates one owning generator per reduction (non-atomic, atomic and
/// `data_ptr_ptr`), then appends one entry per reduction to `reductionInfos`.
/// The owning generator vectors are output parameters.
// NOTE(review): several parameters (listing lines 1433-1434, 1436, 1438), the
// start of the `makeRefDataPtrGen` push (1447), the walk callback's results
// (1463, 1466) and two initializer entries (1473, 1475) appear to be missing
// from this copy -- confirm against the original file.
1429template <typename T>
1430static void collectReductionInfo(
1431 T loop, llvm::IRBuilderBase &builder,
1432 LLVM::ModuleTranslation &moduleTranslation,
1435 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1437 const ArrayRef<llvm::Value *> privateReductionVariables,
1439 ArrayRef<bool> isByRef) {
1440 unsigned numReductions = loop.getNumReductionVars();
1441
1442 for (unsigned i = 0; i < numReductions; ++i) {
1443 owningReductionGens.push_back(
1444 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1445 owningAtomicReductionGens.push_back(
1446 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
1448 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1449 }
1450
1451 // Collect the reduction information.
1452 reductionInfos.reserve(numReductions);
1453 for (unsigned i = 0; i < numReductions; ++i) {
// The atomic generator stays null when the declaration has no atomic region.
1454 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1455 if (owningAtomicReductionGens[i])
1456 atomicGen = owningAtomicReductionGens[i];
1457 llvm::Value *variable =
1458 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
// Element type of an `llvm.alloca` found in the alloc region, if any.
1459 mlir::Type allocatedType;
1460 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1461 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1462 allocatedType = alloca.getElemType();
1464 }
1465
1467 });
1468
1469 reductionInfos.push_back(
1470 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1471 privateReductionVariables[i],
1472 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1474 /*ReductionGenClang=*/nullptr, atomicGen,
1476 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1477 reductionDecls[i].getByrefElementType()
1478 ? moduleTranslation.convertType(
1479 *reductionDecls[i].getByrefElementType())
1480 : nullptr});
1481 }
1482}
1483
1484/// handling of DeclareReductionOp's cleanup region
/// Inlines every non-empty region of `cleanupRegions` at the builder's current
/// position, with the region's first argument mapped to the matching entry of
/// `privateVariables` (loaded first when `shouldLoadCleanupRegionArg` is set).
/// Returns failure as soon as one region cannot be inlined.
// NOTE(review): the function name and first parameter (listing line 1486)
// appear to be missing from this copy -- confirm against the original file.
1485static LogicalResult
1487 llvm::ArrayRef<llvm::Value *> privateVariables,
1488 LLVM::ModuleTranslation &moduleTranslation,
1489 llvm::IRBuilderBase &builder, StringRef regionName,
1490 bool shouldLoadCleanupRegionArg = true) {
1491 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1492 if (cleanupRegion->empty())
1493 continue;
1494
1495 // map the argument to the cleanup region
1496 Block &entry = cleanupRegion->front();
1497
// If the insertion block already ends in a terminator, emit in front of it.
1498 llvm::Instruction *potentialTerminator =
1499 builder.GetInsertBlock()->empty() ? nullptr
1500 : &builder.GetInsertBlock()->back();
1501 if (potentialTerminator && potentialTerminator->isTerminator())
1502 builder.SetInsertPoint(potentialTerminator);
1503 llvm::Value *privateVarValue =
1504 shouldLoadCleanupRegionArg
1505 ? builder.CreateLoad(
1506 moduleTranslation.convertType(entry.getArgument(0).getType()),
1507 privateVariables[i])
1508 : privateVariables[i];
1509
1510 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1511
1512 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1513 moduleTranslation)))
1514 return failure();
1515
1516 // clear block argument mapping in case it needs to be re-created with a
1517 // different source for another use of the same reduction decl
1518 moduleTranslation.forgetMapping(*cleanupRegion);
1519 }
1520 return success();
1521}
1522
// Emits the reductions for `op` through OpenMPIRBuilder's `createReductions`,
// follows them with a barrier unless `isTeamsReduction` is set, inlines the
// reduction declarations' cleanup regions and, in a shared device context,
// frees the private reduction variables. Does nothing when the op has no
// reduction variables.
// NOTE(review): one parameter (listing line 1529) and two local declarations
// (1536, 1539 -- presumably `owningReductionGens` and `reductionInfos`, which
// are used below) appear to be missing from this copy -- confirm against the
// original file.
1523// TODO: not used by ParallelOp
1524template <class OP>
1525static LogicalResult createReductionsAndCleanup(
1526 OP op, llvm::IRBuilderBase &builder,
1527 LLVM::ModuleTranslation &moduleTranslation,
1528 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1530 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1531 bool isNowait = false, bool isTeamsReduction = false) {
1532 // Process the reductions if required.
1533 if (op.getNumReductionVars() == 0)
1534 return success();
1535
1537 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1538 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1540
1541 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1542
1543 // Create the reduction generators. We need to own them here because
1544 // ReductionInfo only accepts references to the generators.
1545 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1546 owningReductionGens, owningAtomicReductionGens,
1547 owningReductionGenRefDataPtrGens,
1548 privateReductionVariables, reductionInfos, isByRef);
1549
1550 // The call to createReductions below expects the block to have a
1551 // terminator. Create an unreachable instruction to serve as terminator
1552 // and remove it later.
1553 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1554 builder.SetInsertPoint(tempTerminator);
1555 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1556 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1557 isByRef, isNowait, isTeamsReduction);
1558
1559 if (failed(handleError(contInsertPoint, *op)))
1560 return failure();
1561
// A successful result without a block is also treated as a failure.
1562 if (!contInsertPoint->getBlock())
1563 return op->emitOpError() << "failed to convert reductions";
1564
1565 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
// Teams reductions are not followed by a barrier.
1566 if (!isTeamsReduction) {
1567 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1568 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1569
1570 if (failed(handleError(barrierIP, *op)))
1571 return failure();
1572 afterIP = *barrierIP;
1573 }
1574
1575 tempTerminator->eraseFromParent();
1576 builder.restoreIP(afterIP);
1577
1578 // after the construct, deallocate private reduction variables
1579 SmallVector<Region *> reductionRegions;
1580 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1581 [](omp::DeclareReductionOp reductionDecl) {
1582 return &reductionDecl.getCleanupRegion();
1583 });
1584 LogicalResult result = inlineOmpRegionCleanup(
1585 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1586 "omp.reduction.cleanup");
1587
// The frees below are emitted even when the cleanup inlining failed; the
// cleanup result is what gets returned.
1588 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1589 if (useDeviceSharedMem) {
1590 for (auto [var, reductionDecl] :
1591 llvm::zip_equal(privateReductionVariables, reductionDecls))
1592 ompBuilder->createOMPFreeShared(
1593 builder, var, moduleTranslation.convertType(reductionDecl.getType()));
1594 }
1595
1596 return result;
1597}
1598
1599static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1600 if (!attr)
1601 return {};
1602 return *attr;
1603}
1604
// Allocates the private reduction variables via `allocReductionVars` and then
// initializes them via `initReductionVars`, handing the stores deferred by the
// first step over to the second. Does nothing when the op has no reduction
// variables; returns failure when either step fails.
// NOTE(review): the function's return type and name (listing line 1607) and
// one parameter (1611) appear to be missing from this copy -- confirm against
// the original file.
1605// TODO: not used by omp.parallel
1606template <typename OP>
1608 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1609 LLVM::ModuleTranslation &moduleTranslation,
1610 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1612 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1613 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1614 llvm::ArrayRef<bool> isByRef) {
1615 if (op.getNumReductionVars() == 0)
1616 return success();
1617
1618 SmallVector<DeferredStore> deferredStores;
1619
1620 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1621 allocaIP, reductionDecls,
1622 privateReductionVariables, reductionVariableMap,
1623 deferredStores, isByRef)))
1624 return failure();
1625
// The alloca block is passed on so that initialization can still add
// allocations to it.
1626 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1627 allocaIP.getBlock(), reductionDecls,
1628 privateReductionVariables, reductionVariableMap,
1629 isByRef, deferredStores);
1630}
1631
1632/// Return the llvm::Value * corresponding to the `privateVar` that
1633/// is being privatized. It isn't always as simple as looking up
1634/// moduleTranslation with privateVar. For instance, in case of
1635/// an allocatable, the descriptor for the allocatable is privatized.
1636/// This descriptor is mapped using an MapInfoOp. So, this function
1637/// will return a pointer to the llvm::Value corresponding to the
1638/// block argument for the mapped descriptor.
1639static llvm::Value *
1640findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1641 LLVM::ModuleTranslation &moduleTranslation,
1642 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1643 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1644 return moduleTranslation.lookupValue(privateVar);
1645
1646 Value blockArg = (*mappedPrivateVars)[privateVar];
1647 Type privVarType = privateVar.getType();
1648 Type blockArgType = blockArg.getType();
1649 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1650 "A block argument corresponding to a mapped var should have "
1651 "!llvm.ptr type");
1652
1653 if (privVarType == blockArgType)
1654 return moduleTranslation.lookupValue(blockArg);
1655
1656 // This typically happens when the privatized type is lowered from
1657 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1658 // struct/pair is passed by value. But, mapped values are passed only as
1659 // pointers, so before we privatize, we must load the pointer.
1660 if (!isa<LLVM::LLVMPointerType>(privVarType))
1661 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1662 moduleTranslation.lookupValue(blockArg));
1663
1664 return moduleTranslation.lookupValue(privateVar);
1665}
1666
1667/// Initialize a single (first)private variable. You probably want to use
1668/// allocateAndInitPrivateVars instead of this.
1669/// This returns the private variable which has been initialized. This
1670/// variable should be mapped before constructing the body of the Op.
///
/// When the privatizer has no init region, `llvmPrivateVar` is returned
/// unchanged. Otherwise the init region's mold and private arguments are mapped
/// to `nonPrivateVar` and `llvmPrivateVar`, the region is inlined through
/// `builder`, and the single value it yields is returned. An error is returned
/// when inlining the region fails.
/// `blockArg`, `privInitBlock` and `mappedPrivateVars` are not referenced in
/// the lines visible here.
/// NOTE(review): the return-type line and the declaration of `phis` are absent
/// from this listing (the embedded numbering skips them) - confirm upstream.
1672initPrivateVar(llvm::IRBuilderBase &builder,
1673 LLVM::ModuleTranslation &moduleTranslation,
1674 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1675 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1676 llvm::BasicBlock *privInitBlock,
1677 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1678 Region &initRegion = privDecl.getInitRegion();
 // No init region: the allocated storage is already the final private var.
1679 if (initRegion.empty())
1680 return llvmPrivateVar;
1681
1682 assert(nonPrivateVar);
 // Bind the init region's two arguments: the original (mold) variable and the
 // freshly allocated private one.
1683 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1684 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1685
1686 // in-place convert the private initialization region
1688 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1689 moduleTranslation, &phis)))
1690 return llvm::createStringError(
1691 "failed to inline `init` region of `omp.private`");
1692
1693 assert(phis.size() == 1 && "expected one allocation to be yielded");
1694
1695 // clear init region block argument mapping in case it needs to be
1696 // re-created with a different source for another use of the same
1697 // omp.private op
1698 moduleTranslation.forgetMapping(initRegion);
1699
1700 // Prefer the value yielded from the init region to the allocated private
1701 // variable in case the region is operating on arguments by-value (e.g.
1702 // Fortran character boxes).
1703 return phis[0];
1704}
1705
1706/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
/// The lookup goes through findAssociatedValue, so `mappedPrivateVars` (when
/// provided) is consulted first; every other argument is forwarded unchanged.
/// NOTE(review): the line carrying the return type and function name is absent
/// from this listing (the embedded numbering skips it) - confirm upstream.
1708 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1709 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1710 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1711 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1712 return initPrivateVar(
1713 builder, moduleTranslation, privDecl,
1714 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1715 mappedPrivateVars),
1716 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1717}
1718
/// Initializes every private variable described by `privateVarsInfo`.
/// Does nothing when there are no private block arguments. Otherwise a new
/// "omp.private.init" block is split off at the builder's position and, for
/// each (privatizer, MLIR variable, block argument, LLVM variable) tuple, the
/// variable is initialized and the block argument is mapped to the result.
/// Returns the first initialization error, or success.
/// NOTE(review): the statement that declares `privVarOrErr` and the statement
/// at the end of the loop body are absent from this listing (the embedded
/// numbering skips them) - confirm upstream.
1719static llvm::Error
1720initPrivateVars(llvm::IRBuilderBase &builder,
1721 LLVM::ModuleTranslation &moduleTranslation,
1722 PrivateVarsInfo &privateVarsInfo,
1723 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1724 if (privateVarsInfo.blockArgs.empty())
1725 return llvm::Error::success();
1726
1727 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1728 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1729
 // zip_equal is used over all four per-variable lists of privateVarsInfo; the
 // enumerate index `idx` is not referenced in the lines visible here.
1730 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1731 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1732 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1733 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1735 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1736 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1737
1738 if (!privVarOrErr)
1739 return privVarOrErr.takeError();
1740
 // Store the initialized value back through the structured binding and bind
 // the region's block argument to it.
1741 llvmPrivateVar = privVarOrErr.get();
1742 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1743
1745 }
1746
1747 return llvm::Error::success();
1748}
1749
1750/// Allocate and initialize delayed private variables. Returns the basic block
1751/// which comes after all of these allocations. llvm::Value * for each of these
1752/// private variables are populated in llvmPrivateVars.
///
/// In the lines visible here only allocation takes place, and the resulting
/// values are appended to `privateVarsInfo.llvmVars`. `mappedPrivateVars`, the
/// loop variable `mlirPrivVar` and the local `dataLayout` are not referenced
/// in the visible lines.
/// NOTE(review): the return-type line is absent from this listing (the
/// embedded numbering skips it) - confirm upstream.
1753template <typename T>
1755allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
1756 LLVM::ModuleTranslation &moduleTranslation,
1757 PrivateVarsInfo &privateVarsInfo,
1758 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1759 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1760 // Allocate private vars
1761 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
 // Split the alloca block right before its terminator; what follows becomes
 // "omp.region.after_alloca".
1762 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1763 allocaTerminator->getIterator()),
1764 true, allocaTerminator->getStableDebugLoc(),
1765 "omp.region.after_alloca");
1766
 // The caller's insertion point is restored when this function returns.
1767 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1768 // Update the allocaTerminator since the alloca block was split above.
1769 allocaTerminator = allocaIP.getBlock()->getTerminator();
1770 builder.SetInsertPoint(allocaTerminator);
1771 // The new terminator is an unconditional branch created by the splitBB above.
1772 assert(allocaTerminator->getNumSuccessors() == 1 &&
1773 "This is an unconditional branch created by splitBB");
1774
1775 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1776 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1777
1778 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1779 bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
 // Address spaces reported by the module's data layout: the one used for
 // allocas and the program address space.
1780 unsigned int allocaAS =
1781 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1782 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1783 ->getDataLayout()
1784 .getProgramAddressSpace();
1785
1786 for (auto [privDecl, mlirPrivVar, blockArg] :
1787 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1788 privateVarsInfo.blockArgs)) {
1789 llvm::Type *llvmAllocType =
1790 moduleTranslation.convertType(privDecl.getType());
 // Every allocation is emitted just before the alloca block's terminator.
1791 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1792 llvm::Value *llvmPrivateVar = nullptr;
1793 if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
1794 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1795 } else {
1796 llvmPrivateVar = builder.CreateAlloca(
1797 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
 // Cast the alloca result to the program address space when the two
 // address spaces differ.
1798 if (allocaAS != defaultAS)
1799 llvmPrivateVar = builder.CreateAddrSpaceCast(
1800 llvmPrivateVar, builder.getPtrTy(defaultAS));
1801 }
1802
1803 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1804 }
1805
1806 return afterAllocas;
1807}
1808
1809/// This can't always be determined statically, but when we can, it is good to
1810/// avoid generating compiler-added barriers which will deadlock the program.
///
/// Walks the parent operations of `op` outwards: returns true as soon as an
/// omp.single or omp.critical parent is found, and false when an omp.parallel
/// parent is reached first or when no parent matches at all.
/// NOTE(review): the signature line is absent from this listing (the embedded
/// numbering skips it) - confirm upstream.
1812 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1813 parent = parent->getParentOp()) {
1814 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1815 return true;
1816
 // A parallel region nested below a single/critical stops the search, as in
 // the example that follows.
1817 // e.g.
1818 // omp.single {
1819 // omp.parallel {
1820 // op
1821 // }
1822 // }
1823 if (mlir::isa<omp::ParallelOp>(parent))
1824 return false;
1825 }
1826 return false;
1827}
1828
/// Inlines the `copy` region of every firstprivate privatizer in
/// `privateDecls`, copying from the matching entry of `moldVars` into the
/// matching entry of `llvmPrivateVars`. Returns success() without emitting
/// anything when no privatizer is firstprivate. When `insertBarrier` is set
/// and opIsInSingleThread(op) is false, a barrier is created after the copies.
/// `mappedPrivateVars` is not referenced in the lines visible here.
/// NOTE(review): the parameter line declaring `moldVars` and one statement
/// after the region inlining are absent from this listing (the embedded
/// numbering skips them) - confirm upstream.
1829static LogicalResult copyFirstPrivateVars(
1830 mlir::Operation *op, llvm::IRBuilderBase &builder,
1831 LLVM::ModuleTranslation &moduleTranslation,
1833 ArrayRef<llvm::Value *> llvmPrivateVars,
1834 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1835 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1836 // Apply copy region for firstprivate.
1837 bool needsFirstprivate =
1838 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1839 return privOp.getDataSharingType() ==
1840 omp::DataSharingClauseType::FirstPrivate;
1841 });
1842
1843 if (!needsFirstprivate)
1844 return success();
1845
 // All copies are emitted into a dedicated "omp.private.copy" block.
1846 llvm::BasicBlock *copyBlock =
1847 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1848 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1849
1850 for (auto [decl, moldVar, llvmVar] :
1851 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
 // Privatizers that are not firstprivate have nothing to copy.
1852 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1853 continue;
1854
1855 // copyRegion implements `lhs = rhs`
1856 Region &copyRegion = decl.getCopyRegion();
1857
 // map copyRegion rhs arg
1858 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1859
1860 // map copyRegion lhs arg
1861 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1862
1863 // in-place convert copy region
1864 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1865 moduleTranslation)))
1866 return decl.emitError("failed to inline `copy` region of `omp.private`");
1867
1869
1870 // ignore unused value yielded from copy region
1871
1872 // clear copy region block argument mapping in case it needs to be
1873 // re-created with different sources for reuse of the same omp.private
1874 // op
1875 moduleTranslation.forgetMapping(copyRegion);
1876 }
1877
 // The barrier is skipped when the op is statically known to sit inside a
 // single/critical construct (see opIsInSingleThread).
1878 if (insertBarrier && !opIsInSingleThread(op)) {
1879 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1880 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1881 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1882 if (failed(handleError(res, *op)))
1883 return failure();
1884 }
1885
1886 return success();
1887}
1888
1889static LogicalResult copyFirstPrivateVars(
1890 mlir::Operation *op, llvm::IRBuilderBase &builder,
1891 LLVM::ModuleTranslation &moduleTranslation,
1892 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1893 ArrayRef<llvm::Value *> llvmPrivateVars,
1894 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1895 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1896 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1897 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1898 // map copyRegion rhs arg
1899 llvm::Value *moldVar = findAssociatedValue(
1900 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1901 assert(moldVar);
1902 return moldVar;
1903 });
1904 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1905 llvmPrivateVars, privateDecls, insertBarrier,
1906 mappedPrivateVars);
1907}
1908
1909template <typename T>
1910static LogicalResult
1911cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
1912 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1913 PrivateVarsInfo &privateVarsInfo) {
1914 // private variable deallocation
1915 SmallVector<Region *> privateCleanupRegions;
1916 llvm::transform(privateVarsInfo.privatizers,
1917 std::back_inserter(privateCleanupRegions),
1918 [](omp::PrivateClauseOp privatizer) {
1919 return &privatizer.getDeallocRegion();
1920 });
1921
1922 if (failed(inlineOmpRegionCleanup(privateCleanupRegions,
1923 privateVarsInfo.llvmVars, moduleTranslation,
1924 builder, "omp.private.dealloc",
1925 /*shouldLoadCleanupRegionArg=*/false)))
1926 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1927 "`omp.private` op in");
1928
1929 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1930 bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1931 for (auto [privDecl, llvmPrivVar, blockArg] :
1932 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
1933 privateVarsInfo.blockArgs)) {
1934 if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
1935 ompBuilder->createOMPFreeShared(
1936 builder, llvmPrivVar,
1937 moduleTranslation.convertType(privDecl.getType()));
1938 }
1939 }
1940
1941 return success();
1942}
1943
1944/// Returns true if the construct contains omp.cancel or omp.cancellation_point
///
/// The check walks `op` and interrupts the walk at the first omp.cancel or
/// omp.cancellation_point it meets.
/// NOTE(review): the signature line is absent from this listing (the embedded
/// numbering skips it) - confirm upstream.
1946 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1947 // be visible and not inside of function calls. This is enforced by the
1948 // verifier.
1949 return op
1950 ->walk([](Operation *child) {
1951 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1952 return WalkResult::interrupt();
1953 return WalkResult::advance();
1954 })
1955 .wasInterrupted();
1956}
1957
/// Converts an omp.sections operation into LLVM IR using OpenMPIRBuilder.
/// Visible steps: check the implementation status, collect the reduction
/// declarations and set up the reduction variables, build one body callback
/// per nested omp.section, call OpenMPIRBuilder::createSections, restore the
/// builder to the returned insertion point and finally hand over to the
/// reduction processing. Returns success() without calling createSections when
/// no omp.section is nested in the operation.
/// NOTE(review): the declarations of `reductionDecls` and `sectionCBs` and the
/// callee names of the reduction set-up / finalization calls are absent from
/// this listing (the embedded numbering skips them) - confirm upstream.
1958static LogicalResult
1959convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1960 LLVM::ModuleTranslation &moduleTranslation) {
1961 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1962 using StorableBodyGenCallbackTy =
1963 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1964
1965 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1966
1967 if (failed(checkImplementationStatus(opInst)))
1968 return failure();
1969
1970 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1971 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1972
1974 collectReductionDecls(sectionsOp, reductionDecls);
1975 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1976 findAllocInsertPoints(builder, moduleTranslation);
1977
 // One slot per reduction variable.
1978 SmallVector<llvm::Value *> privateReductionVariables(
1979 sectionsOp.getNumReductionVars());
1980 DenseMap<Value, llvm::Value *> reductionVariableMap;
1981
1982 MutableArrayRef<BlockArgument> reductionArgs =
1983 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1984
1986 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1987 reductionDecls, privateReductionVariables, reductionVariableMap,
1988 isByRef)))
1989 return failure();
1990
1992
 // Build one callback per omp.section found in the first block of the
 // sections region; any other operation in that block is skipped.
1993 for (Operation &op : *sectionsOp.getRegion().begin()) {
1994 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1995 if (!sectionOp) // omp.terminator
1996 continue;
1997
1998 Region &region = sectionOp.getRegion();
1999 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
2000 InsertPointTy allocaIP, InsertPointTy codeGenIP,
2001 ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2002 builder.restoreIP(codeGenIP);
2003
2004 // map the omp.section reduction block argument to the omp.sections block
2005 // arguments
2006 // TODO: this assumes that the only block arguments are reduction
2007 // variables
2008 assert(region.getNumArguments() ==
2009 sectionsOp.getRegion().getNumArguments());
2010 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
2011 sectionsOp.getRegion().getArguments(), region.getArguments())) {
2012 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
2013 assert(llvmVal);
2014 moduleTranslation.mapValue(sectionArg, llvmVal);
2015 }
2016
2017 return convertOmpOpRegions(region, "omp.section.region", builder,
2018 moduleTranslation)
2019 .takeError();
2020 };
2021 sectionCBs.push_back(sectionCB);
2022 }
2023
2024 // No sections within omp.sections operation - skip generation. This situation
2025 // is only possible if there is only a terminator operation inside the
2026 // sections operation
2027 if (sectionCBs.empty())
2028 return success();
2029
2030 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
2031
2032 // TODO: Perform appropriate actions according to the data-sharing
2033 // attribute (shared, private, firstprivate, ...) of variables.
2034 // Currently defaults to shared.
 // The callback hands back the incoming pointer as the replacement value.
2035 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
2036 llvm::Value &vPtr, llvm::Value *&replacementValue)
2037 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
2038 replacementValue = &vPtr;
2039 return codeGenIP;
2040 };
2041
2042 // TODO: Perform finalization actions for variables. This has to be
2043 // called for variables which have destructors/finalizers.
2044 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
2045
 // The alloca insertion point is looked up again before createSections.
2046 allocaIP = findAllocInsertPoints(builder, moduleTranslation);
2047 bool isCancellable = constructIsCancellable(sectionsOp);
2048 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2049 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2050 moduleTranslation.getOpenMPBuilder()->createSections(
2051 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2052 sectionsOp.getNowait());
2053
2054 if (failed(handleError(afterIP, opInst)))
2055 return failure();
2056
2057 builder.restoreIP(*afterIP);
2058
2059 // Process the reductions if required.
2061 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2062 privateReductionVariables, isByRef, sectionsOp.getNowait());
2063}
2064
2065/// Converts an OpenMP scope construct into LLVM IR.
///
/// Visible steps: check the implementation status, allocate the private
/// variables, set up the reduction variables, then call
/// OpenMPIRBuilder::createScope with a body callback (initialize privates, run
/// firstprivate copies, convert the region) and a finalization callback
/// (clean up privates). Reduction processing follows once the builder has been
/// moved to the insertion point returned by createScope.
/// NOTE(review): the declaration of `reductionDecls`, the statement declaring
/// `afterAllocas` and the callee names of the reduction set-up / finalization
/// calls are absent from this listing (the embedded numbering skips them) -
/// confirm upstream.
2066static LogicalResult
2067convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder,
2068 LLVM::ModuleTranslation &moduleTranslation) {
2069 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2070 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2071
2072 if (failed(checkImplementationStatus(*scopeOp)))
2073 return failure();
2074
2075 llvm::ArrayRef<bool> isByRef = getIsByRef(scopeOp.getReductionByref());
2076 assert(isByRef.size() == scopeOp.getNumReductionVars());
2077
2078 PrivateVarsInfo privateVarsInfo(scopeOp);
2079
2081 collectReductionDecls(scopeOp, reductionDecls);
2082 InsertPointTy allocaIP = findAllocInsertPoints(builder, moduleTranslation);
2083
 // One slot per reduction variable.
2084 SmallVector<llvm::Value *> privateReductionVariables(
2085 scopeOp.getNumReductionVars());
2086 DenseMap<Value, llvm::Value *> reductionVariableMap;
2087
2088 MutableArrayRef<BlockArgument> reductionArgs =
2089 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2090
2091 // Allocate private vars before the scope body
2093 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2094 if (failed(handleError(afterAllocas, *scopeOp)))
2095 return failure();
2096
2098 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2099 reductionDecls, privateReductionVariables, reductionVariableMap,
2100 isByRef)))
2101 return failure();
2102
 // Body callback: failures already reported through handleError / emitError
 // are propagated as PreviouslyReportedError.
2103 auto bodyCB =
2104 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2105 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
2106 builder.restoreIP(codeGenIP);
2107
2108 if (handleError(
2109 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
2110 *scopeOp)
2111 .failed())
2112 return llvm::make_error<PreviouslyReportedError>();
2113
2114 if (failed(copyFirstPrivateVars(
2115 scopeOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2116 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
2117 scopeOp.getPrivateNeedsBarrier())))
2118 return llvm::make_error<PreviouslyReportedError>();
2119
2120 return convertOmpOpRegions(scopeOp.getRegion(), "omp.scope.region", builder,
2121 moduleTranslation)
2122 .takeError();
2123 };
2124
 // Finalization callback: emits the private-variable cleanup at `codeGenIP`
 // and puts the builder back where it was.
 // NOTE(review): on the failure path the saved insertion point is not
 // restored before returning - confirm this is intended.
2125 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2126 InsertPointTy oldIP = builder.saveIP();
2127 builder.restoreIP(codeGenIP);
2128 if (failed(cleanupPrivateVars(scopeOp, builder, moduleTranslation,
2129 scopeOp.getLoc(), privateVarsInfo)))
2130 return llvm::make_error<PreviouslyReportedError>();
2131 builder.restoreIP(oldIP);
2132 return llvm::Error::success();
2133 };
2134
2135 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2136 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2137 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
2138
2139 if (failed(handleError(afterIP, *scopeOp)))
2140 return failure();
2141
2142 builder.restoreIP(*afterIP);
2143
2144 // Process the reductions if required.
2146 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2147 privateReductionVariables, isByRef, scopeOp.getNowait(),
2148 /*isTeamsReduction=*/false);
2149}
2150
2151/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
///
/// The body callback converts the region of the op; the finalization callback
/// does nothing. For every copyprivate variable the LLVM value and the LLVM
/// function looked up from the matching copyprivate symbol are collected and
/// passed to OpenMPIRBuilder::createSingle together with the nowait flag.
/// NOTE(review): the declarations of `llvmCPVars` / `llvmCPFuncs` and the
/// start of the statement defining `llvmFuncOp` are absent from this listing
/// (the embedded numbering skips them) - confirm upstream.
2152static LogicalResult
2153convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
2154 LLVM::ModuleTranslation &moduleTranslation) {
2155 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2156 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2157
2158 if (failed(checkImplementationStatus(*singleOp)))
2159 return failure();
2160
2161 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2162 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2163 builder.restoreIP(codegenIP);
2164 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
2165 builder, moduleTranslation)
2166 .takeError();
2167 };
2168 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
2169
2170 // Handle copyprivate
2171 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
2172 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
 // `cpFuncs` is dereferenced only inside the loop, i.e. only when there is at
 // least one copyprivate variable.
 // NOTE(review): presumably the op verifier guarantees the symbol array is
 // present and as long as `cpVars` - confirm.
2175 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
2176 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
2178 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2179 llvmCPFuncs.push_back(
2180 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
2181 }
2182
2183 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2184 moduleTranslation.getOpenMPBuilder()->createSingle(
2185 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2186 llvmCPFuncs);
2187
2188 if (failed(handleError(afterIP, *singleOp)))
2189 return failure();
2190
2191 builder.restoreIP(*afterIP);
2192 return success();
2193}
2194
/// Looks for the single omp.distribute op nested in `teamsOp` that contains
/// every non-debug use of the teams reduction block arguments. Returns that op,
/// or a null DistributeOp when there is no distribute op, more than one, or
/// when some non-debug use lies outside of it. On success the debug uses
/// (llvm.dbg.declare / llvm.dbg.value) of the reduction arguments are erased
/// as a side effect.
/// NOTE(review): the line with the function name and parameters and the
/// declaration of `debugUses` are absent from this listing (the embedded
/// numbering skips them) - confirm upstream.
2195static omp::DistributeOp
2197 // Early return if we found more than one distribute op or if we can't find
2198 // any distribute op in the teams region.
2199 omp::DistributeOp distOp;
 // The walk is interrupted on the second distribute op; after the first one
 // its nested operations are skipped.
2200 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
2201 if (distOp)
2202 return WalkResult::interrupt();
2203 distOp = op;
2204 return WalkResult::skip();
2205 });
2206 if (walk.wasInterrupted() || !distOp)
2207 return {};
2208
2209 auto iface =
2210 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2211 // Check that all uses of the reduction block arg has the same distribute op
2212 // parent.
2214 for (auto ra : iface.getReductionBlockArgs())
2215 for (auto &use : ra.getUses()) {
2216 auto *useOp = use.getOwner();
2217 // Ignore debug uses.
2218 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2219 debugUses.push_back(useOp);
2220 continue;
2221 }
 // A use outside of the distribute op rules out distribute reduction.
2222 if (!distOp->isProperAncestor(useOp))
2223 return {};
2224 }
2225
2226 // If we are going to use distribute reduction then remove any debug uses of
2227 // the reduction parameters in teamsOp. Otherwise they will be left without
2228 // any mapped value in moduleTranslation and will eventually error out.
2229 for (auto *use : debugUses)
2230 use->erase();
2231 return distOp;
2232}
2233
2234// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
//
// Reduction variables are set up and processed here only when
// getDistributeCapturingTeamsReduction finds no distribute op capturing them.
// The optional num_teams lower/upper bounds, thread_limit and if expression
// are looked up and passed to OpenMPIRBuilder::createTeams; the body callback
// converts the region of the op.
// NOTE(review): the declaration of `reductionDecls`, the callee names of the
// reduction set-up / finalization calls and the first line of the statement at
// the top of `bodyCB` are absent from this listing (the embedded numbering
// skips them) - confirm upstream.
2235static LogicalResult
2236convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2237 LLVM::ModuleTranslation &moduleTranslation) {
2238 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2239 if (failed(checkImplementationStatus(*op)))
2240 return failure();
2241
2242 DenseMap<Value, llvm::Value *> reductionVariableMap;
2243 unsigned numReductionVars = op.getNumReductionVars();
2245 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
 // Stays empty unless the teams reduction is performed below.
2246 llvm::ArrayRef<bool> isByRef;
2247 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2248 findAllocInsertPoints(builder, moduleTranslation);
2249
2250 // Only do teams reduction if there is no distribute op that captures the
2251 // reduction instead.
2252 bool doTeamsReduction = !getDistributeCapturingTeamsReduction(op);
2253 if (doTeamsReduction) {
2254 isByRef = getIsByRef(op.getReductionByref());
2255
2256 assert(isByRef.size() == op.getNumReductionVars());
2257
2258 MutableArrayRef<BlockArgument> reductionArgs =
2259 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2260
2261 collectReductionDecls(op, reductionDecls);
2262
2264 op, reductionArgs, builder, moduleTranslation, allocaIP,
2265 reductionDecls, privateReductionVariables, reductionVariableMap,
2266 isByRef)))
2267 return failure();
2268 }
2269
2270 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2271 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2273 moduleTranslation, allocaIP, deallocBlocks);
2274 builder.restoreIP(codegenIP);
2275 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2276 moduleTranslation)
2277 .takeError();
2278 };
2279
 // Each clause value below stays null when the clause is absent.
2280 llvm::Value *numTeamsLower = nullptr;
2281 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2282 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2283
 // Only the first num_teams upper value is used.
2284 llvm::Value *numTeamsUpper = nullptr;
2285 if (!op.getNumTeamsUpperVars().empty())
2286 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2287
 // Only the first thread_limit value is used.
2288 llvm::Value *threadLimit = nullptr;
2289 if (!op.getThreadLimitVars().empty())
2290 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2291
2292 llvm::Value *ifExpr = nullptr;
2293 if (Value ifVar = op.getIfExpr())
2294 ifExpr = moduleTranslation.lookupValue(ifVar);
2295
2296 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2297 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2298 moduleTranslation.getOpenMPBuilder()->createTeams(
2299 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2300
2301 if (failed(handleError(afterIP, *op)))
2302 return failure();
2303
2304 builder.restoreIP(*afterIP);
2305 if (doTeamsReduction) {
2306 // Process the reductions if required.
2308 op, builder, moduleTranslation, allocaIP, reductionDecls,
2309 privateReductionVariables, isByRef,
2310 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2311 }
2312 return success();
2313}
2314
2315static llvm::omp::RTLDependenceKindTy
2316convertDependKind(mlir::omp::ClauseTaskDepend kind) {
2317 switch (kind) {
2318 case mlir::omp::ClauseTaskDepend::taskdependin:
2319 return llvm::omp::RTLDependenceKindTy::DepIn;
2320 // The OpenMP runtime requires that the codegen for 'depend' clause for
2321 // 'out' dependency kind must be the same as codegen for 'depend' clause
2322 // with 'inout' dependency.
2323 case mlir::omp::ClauseTaskDepend::taskdependout:
2324 case mlir::omp::ClauseTaskDepend::taskdependinout:
2325 return llvm::omp::RTLDependenceKindTy::DepInOut;
2326 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2327 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2328 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2329 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2330 }
2331 llvm_unreachable("unhandled depend kind");
2332}
2333
// Appends one llvm::OpenMPIRBuilder::DependData entry to `dds` for each
// (dependence variable, dependence kind) pair: the kind is converted with
// convertDependKind and the variable is looked up in `moduleTranslation`.
// Nothing is appended when `dependVars` is empty.
// NOTE(review): the line with the return type and function name and the
// parameter line declaring `dds` are absent from this listing (the embedded
// numbering skips them) - confirm upstream.
2335 std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2336 LLVM::ModuleTranslation &moduleTranslation,
2338 if (dependVars.empty())
2339 return;
 // `dependKinds` is dereferenced without a has_value() check; llvm::zip (not
 // zip_equal) is used, so no length check is visible here.
 // NOTE(review): presumably the op verifier guarantees one kind per variable
 // - confirm.
2340 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2341 auto kind =
2342 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2343 llvm::omp::RTLDependenceKindTy type = convertDependKind(kind);
2344 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2345 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2346 dds.emplace_back(dd);
2347 }
2348}
2349
2350/// Shared implementation of a callback which adds a terminator for the new block
2351/// created for the branch taken when an openmp construct is cancelled. The
2352/// terminator is saved in \p cancelTerminators. This callback is invoked only
2353/// if there is cancellation inside of the taskgroup body.
2354/// The terminator will need to be fixed to branch to the correct block to
2355/// cleanup the construct.
///
/// The callback is registered through OpenMPIRBuilder::pushFinalizationCB
/// together with \p cancelDirective and the result of
/// constructIsCancellable(op).
/// NOTE(review): the line with the return type and function name is absent
/// from this listing (the embedded numbering skips it) - confirm upstream.
2357 SmallVectorImpl<llvm::UncondBrInst *> &cancelTerminators,
2358 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2359 mlir::Operation *op, llvm::omp::Directive cancelDirective) {
2360 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
 // The builder's insertion point is restored when the callback returns.
2361 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2362
2363 // ip is currently in the block branched to if cancellation occurred.
2364 // We need to create a branch to terminate that block.
2365 llvmBuilder.restoreIP(ip);
2366
2367 // We must still clean up the construct after cancelling it, so we need to
2368 // branch to the block that finalizes the taskgroup.
2369 // That block has not been created yet so use this block as a dummy for now
2370 // and fix this after creating the operation.
2371 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2372 return llvm::Error::success();
2373 };
2374 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2375 // created in case the body contains omp.cancel (which will then expect to be
2376 // able to find this cleanup callback).
2377 ompBuilder.pushFinalizationCB(
2378 {finiCB, cancelDirective, constructIsCancellable(op)});
2379}
2380
2381/// If we cancelled the construct, we should branch to the finalization block of
2382/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2383/// is immediately before the continuation block. Now this finalization has
2384/// been created we can fix the branch.
///
/// Pops the finalization callback from \p ompBuilder and redirects every
/// branch in `cancelTerminators` to the single predecessor of the block of
/// \p afterIP.
/// NOTE(review): the line with the function name and the parameter declaring
/// `cancelTerminators` is absent from this listing (the embedded numbering
/// skips it) - confirm upstream.
2385static void
2387 llvm::OpenMPIRBuilder &ompBuilder,
2388 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2389 ompBuilder.popFinalizationCB();
 // NOTE(review): getSinglePredecessor() is used without a null check in the
 // visible lines; this relies on the CFG shape described above - confirm.
2390 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2391 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2392 cancelBranch->setSuccessor(constructFini);
2393}
2394
2395namespace {
2396/// TaskContextStructManager takes care of creating and freeing a structure
2397/// containing information needed by the task body to execute.
///
/// The class keeps references to the IR builder and the module translation it
/// was constructed with, plus a view of the privatizer list. The member
/// functions declared here are defined outside of the class body.
2398class TaskContextStructManager {
2399public:
 /// Stores the builder, the module translation and the privatizer view; no
 /// other work is done by the constructor.
2400 TaskContextStructManager(llvm::IRBuilderBase &builder,
2401 LLVM::ModuleTranslation &moduleTranslation,
2402 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2403 : builder{builder}, moduleTranslation{moduleTranslation},
2404 privateDecls{privateDecls} {}
2405
2406 /// Creates a heap allocated struct containing space for each private
2407 /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
2408 /// the structure should all have the same order (although privateDecls which
2409 /// do not read from the mold argument are skipped).
2410 void generateTaskContextStruct();
2411
2412 /// Create GEPs to access each member of the structure representing a private
2413 /// variable, adding them to llvmPrivateVars. Null values are added where
2414 /// private decls were skipped so that the ordering continues to match the
2415 /// private decls.
2416 void createGEPsToPrivateVars();
2417
2418 /// Given the address of the structure, return a GEP for each private variable
2419 /// in the structure. Null values are added where private decls were skipped
2420 /// so that the ordering continues to match the private decls.
2421 /// Must be called after generateTaskContextStruct().
2422 SmallVector<llvm::Value *>
2423 createGEPsToPrivateVars(llvm::Value *altStructPtr) const;
2424
2425 /// De-allocate the task context structure.
2426 void freeStructPtr();
2427
 /// Mutable view of the stored per-variable values (llvmPrivateVarGEPs).
2428 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2429 return llvmPrivateVarGEPs;
2430 }
2431
 /// Returns the stored structure pointer; it is initialized to nullptr.
2432 llvm::Value *getStructPtr() { return structPtr; }
2433
2434private:
 // Held by reference: the manager does not own the builder or the module
 // translation.
2435 llvm::IRBuilderBase &builder;
2436 LLVM::ModuleTranslation &moduleTranslation;
2437 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2438
2439 /// The type of each member of the structure, in order.
2440 SmallVector<llvm::Type *> privateVarTypes;
2441
2442 /// LLVM values for each private variable, or null if that private variable is
2443 /// not included in the task context structure
2444 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2445
2446 /// A pointer to the structure containing context for this task.
2447 llvm::Value *structPtr = nullptr;
2448 /// The type of the structure
2449 llvm::Type *structTy = nullptr;
2450};
2451
2452/// IteratorInfo extracts and prepares loop bounds information from an
2453/// mlir::omp::IteratorOp for lowering to LLVM IR.
2454///
2455/// It computes the per-dimension trip counts and the total linearized trip
2456/// count, casted to i64. These are used to build a canonical loop and to
2457/// reconstruct the physical induction variables inside the loop body.
2458class IteratorInfo {
2459private:
2460 llvm::SmallVector<llvm::Value *> lowerBounds;
2461 llvm::SmallVector<llvm::Value *> upperBounds;
2462 llvm::SmallVector<llvm::Value *> steps;
2463 llvm::SmallVector<llvm::Value *> trips;
2464 unsigned dims;
2465 llvm::Value *totalTrips;
2466
2467 llvm::Value *lookUpAsI64(mlir::Value val, const LLVM::ModuleTranslation &mt,
2468 llvm::IRBuilderBase &builder) {
2469 llvm::Value *v = mt.lookupValue(val);
2470 if (!v)
2471 return nullptr;
2472 if (v->getType()->isIntegerTy(64))
2473 return v;
2474 if (v->getType()->isIntegerTy())
2475 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2476 return nullptr;
2477 }
2478
2479public:
2480 IteratorInfo(mlir::omp::IteratorOp itersOp,
2481 mlir::LLVM::ModuleTranslation &moduleTranslation,
2482 llvm::IRBuilderBase &builder) {
2483 dims = itersOp.getLoopLowerBounds().size();
2484 lowerBounds.resize(dims);
2485 upperBounds.resize(dims);
2486 steps.resize(dims);
2487 trips.resize(dims);
2488
2489 for (unsigned d = 0; d < dims; ++d) {
2490 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2491 moduleTranslation, builder);
2492 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2493 moduleTranslation, builder);
2494 llvm::Value *st =
2495 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2496 assert(lb && ub && st &&
2497 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2498 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2499 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2500 "Expect non-zero step in IteratorOp");
2501
2502 lowerBounds[d] = lb;
2503 upperBounds[d] = ub;
2504 steps[d] = st;
2505
2506 // trips = ((ub - lb) / step) + 1 (inclusive ub, assume positive step)
2507 llvm::Value *diff = builder.CreateSub(ub, lb);
2508 llvm::Value *div = builder.CreateSDiv(diff, st);
2509 trips[d] = builder.CreateAdd(
2510 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2511 }
2512
2513 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2514 for (unsigned d = 0; d < dims; ++d)
2515 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2516 }
2517
2518 unsigned getDims() const { return dims; }
2519 llvm::ArrayRef<llvm::Value *> getLowerBounds() const { return lowerBounds; }
2520 llvm::ArrayRef<llvm::Value *> getUpperBounds() const { return upperBounds; }
2521 llvm::ArrayRef<llvm::Value *> getSteps() const { return steps; }
2522 llvm::ArrayRef<llvm::Value *> getTrips() const { return trips; }
2523 llvm::Value *getTotalTrips() const { return totalTrips; }
2524};
2525
2526} // namespace
2527
2528void TaskContextStructManager::generateTaskContextStruct() {
2529 if (privateDecls.empty())
2530 return;
2531 privateVarTypes.reserve(privateDecls.size());
2532
2533 for (omp::PrivateClauseOp &privOp : privateDecls) {
2534 // Skip private variables which can safely be allocated and initialised
2535 // inside of the task
2536 if (!privOp.readsFromMold())
2537 continue;
2538 Type mlirType = privOp.getType();
2539 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2540 }
2541
2542 if (privateVarTypes.empty())
2543 return;
2544
2545 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2546 privateVarTypes);
2547
2548 llvm::DataLayout dataLayout =
2549 builder.GetInsertBlock()->getModule()->getDataLayout();
2550 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2551 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2552
2553 // Heap allocate the structure
2554 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2555 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2556 "omp.task.context_ptr");
2557}
2558
2559SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2560 llvm::Value *altStructPtr) const {
2561 SmallVector<llvm::Value *> ret;
2562
2563 // Create GEPs for each struct member
2564 ret.reserve(privateDecls.size());
2565 llvm::Value *zero = builder.getInt32(0);
2566 unsigned i = 0;
2567 for (auto privDecl : privateDecls) {
2568 if (!privDecl.readsFromMold()) {
2569 // Handle this inside of the task so we don't pass unnessecary vars in
2570 ret.push_back(nullptr);
2571 continue;
2572 }
2573 llvm::Value *iVal = builder.getInt32(i);
2574 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2575 ret.push_back(gep);
2576 i += 1;
2577 }
2578 return ret;
2579}
2580
2581void TaskContextStructManager::createGEPsToPrivateVars() {
2582 if (!structPtr)
2583 assert(privateVarTypes.empty());
2584 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2585 // with null values for skipped private decls
2586
2587 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2588}
2589
2590void TaskContextStructManager::freeStructPtr() {
2591 if (!structPtr)
2592 return;
2593
2594 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2595 // Ensure we don't put the call to free() after the terminator
2596 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2597 builder.CreateFree(structPtr);
2598}
2599
2600static void storeAffinityEntry(llvm::IRBuilderBase &builder,
2601 llvm::OpenMPIRBuilder &ompBuilder,
2602 llvm::Value *affinityList, llvm::Value *index,
2603 llvm::Value *addr, llvm::Value *len) {
2604 llvm::StructType *kmpTaskAffinityInfoTy =
2605 ompBuilder.getKmpTaskAffinityInfoTy();
2606 llvm::Value *entry = builder.CreateInBoundsGEP(
2607 kmpTaskAffinityInfoTy, affinityList, index, "omp.affinity.entry");
2608
2609 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2610 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2611 /*isSigned=*/false);
2612 llvm::Value *flags = builder.getInt32(0);
2613
2614 builder.CreateStore(addr,
2615 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2616 builder.CreateStore(len,
2617 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2618 builder.CreateStore(flags,
2619 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2620}
2621
 // NOTE(review): the line carrying this function's return type, name and
 // first parameter is not visible in this excerpt; the body iterates a range
 // named `affinityVars`, which is presumably that first parameter.
2623 llvm::IRBuilderBase &builder,
2624 LLVM::ModuleTranslation &moduleTranslation,
2625 llvm::Value *affinityList) {
 // Store one affinity entry per affinity variable, at the index given by
 // its position in `affinityVars`.
2626 for (auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2627 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2628 assert(entryOp && "affinity item must be omp.affinity_entry");
2629
 // Both operands of the entry op must already have LLVM values mapped.
2630 llvm::Value *addr = moduleTranslation.lookupValue(entryOp.getAddr());
2631 llvm::Value *len = moduleTranslation.lookupValue(entryOp.getLen());
2632 assert(addr && len && "expect affinity addr and len to be non-null");
2633 storeAffinityEntry(builder, *moduleTranslation.getOpenMPBuilder(),
2634 affinityList, builder.getInt64(i), addr, len);
2635 }
2636}
2637
2638static mlir::LogicalResult
2639convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo,
2640 mlir::Block &iteratorRegionBlock,
2641 llvm::IRBuilderBase &builder,
2642 LLVM::ModuleTranslation &moduleTranslation) {
2643 llvm::Value *tmp = linearIV;
2644 for (int d = (int)iterInfo.getDims() - 1; d >= 0; --d) {
2645 llvm::Value *trip = iterInfo.getTrips()[d];
2646 // idx_d = tmp % trip_d
2647 llvm::Value *idx = builder.CreateURem(tmp, trip);
2648 // tmp = tmp / trip_d
2649 tmp = builder.CreateUDiv(tmp, trip);
2650
2651 // physIV_d = lb_d + idx_d * step_d
2652 llvm::Value *physIV = builder.CreateAdd(
2653 iterInfo.getLowerBounds()[d],
2654 builder.CreateMul(idx, iterInfo.getSteps()[d]), "omp.it.phys_iv");
2655
2656 moduleTranslation.mapValue(iteratorRegionBlock.getArgument(d), physIV);
2657 }
2658
2659 // Translate the iterator region into the loop body.
2660 moduleTranslation.mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2661 if (mlir::failed(moduleTranslation.convertBlock(iteratorRegionBlock,
2662 /*ignoreArguments=*/true,
2663 builder))) {
2664 return mlir::failure();
2665 }
2666 return mlir::success();
2667}
2668
2670 llvm::function_ref<void(llvm::Value *linearIV, mlir::omp::YieldOp yield)>;
2671
2672static mlir::LogicalResult
2673fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder,
2674 mlir::LLVM::ModuleTranslation &moduleTranslation,
2675 IteratorInfo &iterInfo, llvm::StringRef loopName,
2676 IteratorStoreEntryTy genStoreEntry) {
2677 mlir::Region &itersRegion = itersOp.getRegion();
2678 mlir::Block &iteratorRegionBlock = itersRegion.front();
2679
2680 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2681
2682 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2683 llvm::Value *linearIV) -> llvm::Error {
2684 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2685 builder.restoreIP(bodyIP);
2686
2687 if (failed(convertIteratorRegion(linearIV, iterInfo, iteratorRegionBlock,
2688 builder, moduleTranslation))) {
2689 return llvm::make_error<llvm::StringError>(
2690 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2691 }
2692
2693 auto yield =
2694 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.getTerminator());
2695 assert(yield && yield.getResults().size() == 1 &&
2696 "expect omp.yield in iterator region to have one result");
2697
2698 genStoreEntry(linearIV, yield);
2699
2700 // Iterator-region block/value mappings are temporary for this conversion,
2701 // clear them to avoid stale entries in ModuleTranslation.
2702 moduleTranslation.forgetMapping(itersRegion);
2703
2704 return llvm::Error::success();
2705 };
2706
2707 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2708 moduleTranslation.getOpenMPBuilder()->createIteratorLoop(
2709 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2710 if (failed(handleError(afterIP, *itersOp)))
2711 return failure();
2712
2713 builder.restoreIP(*afterIP);
2714
2715 return mlir::success();
2716}
2717
/// Fills \p ad with the affinity count and affinity-info pointer of
/// \p taskOp.
///
/// When the task has neither affinity variables nor iterated values, both
/// fields of \p ad are set to null. Otherwise one affinity list is built for
/// the plain affinity variables and one for each iterated value. If more
/// than one list results, they are copied back to back into a single packed
/// list. Returns failure when emitting an iterator loop fails.
2718static mlir::LogicalResult
2719buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder,
2720 mlir::LLVM::ModuleTranslation &moduleTranslation,
2721 llvm::OpenMPIRBuilder::AffinityData &ad) {
2722
 // Nothing to do: report "no affinity" through null count and info.
2723 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2724 ad.Count = nullptr;
2725 ad.Info = nullptr;
2726 return mlir::success();
2727 }
2728
 // NOTE(review): the declaration of `ads` (the collection of per-list
 // AffinityData used below) is not visible in this excerpt.
2730 llvm::StructType *kmpTaskAffinityInfoTy =
2731 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2732
 // Allocates `count` affinity-info entries with an alloca. For a constant or
 // function-argument count the alloca is placed at the alloca insertion
 // point; otherwise it is emitted at the current insertion point.
2733 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2734 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2735 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2736 builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
2737 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2738 "omp.affinity_list");
2739 };
2740
 // Packs a (count, list) pair into an AffinityData: the count is truncated
 // to i32 and the list pointer is cast to address space 0.
2741 auto createAffinity =
2742 [&](llvm::Value *count,
2743 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2744 llvm::OpenMPIRBuilder::AffinityData ad{};
2745 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2746 ad.Info =
2747 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2748 return ad;
2749 };
2750
 // Plain affinity variables: one list whose size is known at translation
 // time.
2751 if (!taskOp.getAffinityVars().empty()) {
2752 llvm::Value *count = llvm::ConstantInt::get(
2753 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2754 llvm::Value *list = allocateAffinityList(count);
2755 fillAffinityLocators(taskOp.getAffinityVars(), builder, moduleTranslation,
2756 list);
2757 ads.emplace_back(createAffinity(count, list));
2758 }
2759
 // Iterated values: one list per omp.iterator, sized by its total trip
 // count and filled by a generated loop.
2760 if (!taskOp.getIterated().empty()) {
2761 for (auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2762 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2763 assert(itersOp && "iterated value must be defined by omp.iterator");
2764 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2765 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2766 if (failed(fillIteratorLoop(
2767 itersOp, builder, moduleTranslation, iterInfo, "iterator",
2768 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2769 auto entryOp = yield.getResults()[0]
2770 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2771 assert(entryOp && "expect yield produce an affinity entry");
2772 llvm::Value *addr =
2773 moduleTranslation.lookupValue(entryOp.getAddr());
2774 llvm::Value *len =
2775 moduleTranslation.lookupValue(entryOp.getLen());
2776 storeAffinityEntry(builder,
2777 *moduleTranslation.getOpenMPBuilder(),
2778 affList, linearIV, addr, len);
2779 })))
2780 return llvm::failure();
2781 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2782 }
2783 }
2784
 // Sum the i32 counts of all lists.
2785 llvm::Value *totalAffinityCount = builder.getInt32(0);
2786 for (const auto &affinity : ads)
2787 totalAffinityCount = builder.CreateAdd(
2788 totalAffinityCount,
2789 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2790 /*isSigned=*/false));
2791
 // A single list is used directly; several lists are merged below.
2792 llvm::Value *affinityInfo = ads.front().Info;
2793 if (ads.size() > 1) {
2794 llvm::StructType *kmpTaskAffinityInfoTy =
2795 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2796 llvm::Value *affinityInfoElemSize = builder.getInt64(
2797 moduleTranslation.getLLVMModule()->getDataLayout().getTypeAllocSize(
2798 kmpTaskAffinityInfoTy));
2799
 // Copy every list into one packed allocation; the running offset counts
 // entries, not bytes.
2800 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2801 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2802 for (const auto &affinity : ads) {
2803 llvm::Value *affinityCount = builder.CreateIntCast(
2804 affinity.Count, builder.getInt32Ty(), /*isSigned=*/false);
2805 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2806 affinityCount, builder.getInt64Ty(), /*isSigned=*/false);
 // Number of bytes to copy for this list.
2807 llvm::Value *affinityInfoSize =
2808 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2809
 // Address of the destination slot for this list's first entry.
2810 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2811 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2812 /*isSigned=*/false);
2813 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2814 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2815
2816 builder.CreateMemCpy(
2817 packedAffinityInfoIndex, llvm::Align(1),
2818 builder.CreatePointerBitCastOrAddrSpaceCast(
2819 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2820 ->getPointerAddressSpace())),
2821 llvm::Align(1), affinityInfoSize);
2822
2823 packedAffinityInfoOffset =
2824 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2825 }
2826
2827 affinityInfo = packedAffinityInfo;
2828 }
2829
2830 ad.Count = totalAffinityCount;
2831 ad.Info = affinityInfo;
2832
2833 return mlir::success();
2834}
2835
2836// Allocates a single kmp_dep_info array sized to hold both locator
2837// (non-iterated) and iterated entries, fills the locator entries first, then
2838// runs an iterator loop for each iterator modifier object.
//
// When there are no iterated dependencies no array is allocated: the locator
// dependencies are collected into taskDeps.Deps and the function returns.
// Otherwise taskDeps.DepArray receives the heap-allocated array and
// taskDeps.NumDeps the total entry count truncated to i32. Returns failure
// when emitting an iterator loop fails.
2839static mlir::LogicalResult
2840buildDependData(OperandRange dependVars, std::optional<ArrayAttr> dependKinds,
2841 OperandRange dependIterated,
2842 std::optional<ArrayAttr> dependIteratedKinds,
2843 llvm::IRBuilderBase &builder,
2844 mlir::LLVM::ModuleTranslation &moduleTranslation,
2845 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
 // Fast path: only locator dependencies.
2846 if (dependIterated.empty()) {
2847 buildDependDataLocator(dependKinds, dependVars, moduleTranslation,
2848 taskDeps.Deps);
2849 return mlir::success();
2850 }
2851
2852 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2853 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2854 unsigned numLocator = dependVars.size();
2855
2856 // Compute total count: locator deps + sum of iterator trip counts.
2857 llvm::Value *totalCount =
2858 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2859
 // NOTE(review): the declaration of `iterInfos` (one IteratorInfo per
 // iterated dependency) is not visible in this excerpt.
2861 for (auto iter : dependIterated) {
2862 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2863 assert(itersOp && "depend_iterated value must be defined by omp.iterator");
2864 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2865 totalCount =
2866 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2867 }
2868
2869 // Heap-allocate the kmp_depend_info array so we don't risk
2870 // dynamic-sized alloca outside the entry block (e.g. inside loops).
2871 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2872 llvm::Value *depArray =
2873 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2874 totalCount, /*MallocF=*/nullptr, ".dep.arr.addr");
2875
2876 // Fill non-iterated entries at indices [0, numLocator).
2877 if (numLocator > 0) {
 // NOTE(review): the declaration of `dds` (the locator DependData list)
 // is not visible in this excerpt.
2879 buildDependDataLocator(dependKinds, dependVars, moduleTranslation, dds);
2880 for (auto [i, dd] : llvm::enumerate(dds)) {
2881 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2882 llvm::Value *entry =
2883 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2884 ompBuilder.emitTaskDependency(builder, entry, dd);
2885 }
2886 }
2887
2888 // Fill iterated entries starting at index numLocator.
2889 llvm::Value *offset =
2890 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2891 for (auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
 // The i-th kind attribute belongs to the i-th iterated dependency.
2892 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2893 dependIteratedKinds->getValue()[i]);
2894 llvm::omp::RTLDependenceKindTy rtlKind =
2895 convertDependKind(kindAttr.getValue());
2896
2897 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2898 if (failed(fillIteratorLoop(
2899 itersOp, builder, moduleTranslation, iterInfo, "dep_iterator",
2900 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
 // The entry for this iteration lives at offset + linearIV.
2901 llvm::Value *addr =
2902 moduleTranslation.lookupValue(yield.getResults()[0]);
2903 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2904 llvm::Value *entry =
2905 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2906 ompBuilder.emitTaskDependency(
2907 builder, entry,
2908 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2909 addr});
2910 })))
2911 return mlir::failure();
2912
2913 // Advance offset by the trip count of this iterator.
2914 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2915 }
2916
2917 taskDeps.DepArray = depArray;
2918 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2919 return mlir::success();
2920}
2921
2922/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
///
/// Private variables that read from their mold are allocated in a heap
/// allocated task context structure and initialized/copied before the task
/// is created; all other private variables are allocated and initialized
/// inside the task body. Affinity and dependency data are built before the
/// call to createTask(). Returns failure if any of these steps fails.
2923static LogicalResult
2924convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2925 LLVM::ModuleTranslation &moduleTranslation) {
2926 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2927 if (failed(checkImplementationStatus(*taskOp)))
2928 return failure();
2929
2930 PrivateVarsInfo privateVarsInfo(taskOp);
2931 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2932 privateVarsInfo.privatizers};
2933
2934 // Allocate and copy private variables before creating the task. This avoids
2935 // accessing invalid memory if (after this scope ends) the private variables
2936 // are initialized from host variables or if the variables are copied into
2937 // from host variables (firstprivate). The insertion point is just before
2938 // where the code for creating and scheduling the task will go. That puts this
2939 // code outside of the outlined task region, which is what we want because
2940 // this way the initialization and copy regions are executed immediately while
2941 // the host variable data are still live.
 // NOTE(review): the declaration of `deallocBlocks` is not visible in this
 // excerpt.
2943 InsertPointTy allocaIP =
2944 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
2945
2946 // Not using splitBB() because that requires the current block to have a
2947 // terminator.
2948 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2949 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2950 builder.getContext(), "omp.task.start",
2951 /*Parent=*/builder.GetInsertBlock()->getParent());
2952 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2953 builder.SetInsertPoint(branchToTaskStartBlock);
2954
2955 // Now do this again to make the initialization and copy blocks
2956 llvm::BasicBlock *copyBlock =
2957 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2958 llvm::BasicBlock *initBlock =
2959 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2960
2961 // Now the control flow graph should look like
2962 // starter_block:
2963 // <---- where we started when convertOmpTaskOp was called
2964 // br %omp.private.init
2965 // omp.private.init:
2966 // br %omp.private.copy
2967 // omp.private.copy:
2968 // br %omp.task.start
2969 // omp.task.start:
2970 // <---- where we want the insertion point to be when we call createTask()
2971
2972 // Save the alloca insertion point on ModuleTranslation stack for use in
2973 // nested regions.
 // NOTE(review): the first line of this statement is not visible in this
 // excerpt.
2975 moduleTranslation, allocaIP, deallocBlocks);
2976
2977 // Allocate and initialize private variables
2978 builder.SetInsertPoint(initBlock->getTerminator());
2979
2980 // Create task variable structure
2981 taskStructMgr.generateTaskContextStruct();
2982 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2983 // of the body otherwise it will be the GEP not the struct which is forwarded
2984 // to the outlined function. GEPs forwarded in this way are passed in a
2985 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2986 // which may not be executed until after the current stack frame goes out of
2987 // scope.
2988 taskStructMgr.createGEPsToPrivateVars();
2989
2990 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2991 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2992 privateVarsInfo.blockArgs,
2993 taskStructMgr.getLLVMPrivateVarGEPs())) {
2994 // To be handled inside the task.
2995 if (!privDecl.readsFromMold())
2996 continue;
2997 assert(llvmPrivateVarAlloc &&
2998 "reads from mold so shouldn't have been skipped");
2999
3000 llvm::Expected<llvm::Value *> privateVarOrErr =
3001 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3002 blockArg, llvmPrivateVarAlloc, initBlock);
3003 if (!privateVarOrErr)
3004 return handleError(privateVarOrErr, *taskOp.getOperation());
3005
 // NOTE(review): one statement is not visible in this excerpt here.
3007
3008 // TODO: this is a bit of a hack for Fortran character boxes.
3009 // Character boxes are passed by value into the init region and then the
3010 // initialized character box is yielded by value. Here we need to store the
3011 // yielded value into the private allocation, and load the private
3012 // allocation to match the type expected by region block arguments.
3013 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3014 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3015 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3016 // Load it so we have the value pointed to by the GEP
3017 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3018 llvmPrivateVarAlloc);
3019 }
3020 assert(llvmPrivateVarAlloc->getType() ==
3021 moduleTranslation.convertType(blockArg.getType()));
3022
3023 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
3024 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
3025 // stack allocated structure.
3026 }
3027
3028 // firstprivate copy region
3029 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3030 if (failed(copyFirstPrivateVars(
3031 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3032 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3033 taskOp.getPrivateNeedsBarrier())))
3034 return llvm::failure();
3035
 // Affinity lists are emitted here, before the insertion point moves to the
 // task start block.
3036 llvm::OpenMPIRBuilder::AffinityData ad;
3037 if (failed(buildAffinityData(taskOp, builder, moduleTranslation, ad)))
3038 return llvm::failure();
3039
3040 // Set up for call to createTask()
3041 builder.SetInsertPoint(taskStartBlock);
3042
 // Callback passed to createTask() to generate the task body.
3043 auto bodyCB =
3044 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3045 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3046 // Save the alloca insertion point on ModuleTranslation stack for use in
3047 // nested regions.
 // NOTE(review): the first line of this statement is not visible in this
 // excerpt.
3049 moduleTranslation, allocaIP, deallocBlocks);
3050
3051 // translate the body of the task:
3052 builder.restoreIP(codegenIP);
3053
 // Private variables that do not read from their mold are allocated and
 // initialized here, inside the task.
3054 llvm::BasicBlock *privInitBlock = nullptr;
3055 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3056 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3057 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3058 privateVarsInfo.mlirVars))) {
3059 auto [blockArg, privDecl, mlirPrivVar] = zip;
3060 // This is handled before the task executes
3061 if (privDecl.readsFromMold())
3062 continue;
3063
3064 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3065 llvm::Type *llvmAllocType =
3066 moduleTranslation.convertType(privDecl.getType());
3067 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3068 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3069 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3070
3071 llvm::Expected<llvm::Value *> privateVarOrError =
3072 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3073 blockArg, llvmPrivateVar, privInitBlock);
3074 if (!privateVarOrError)
3075 return privateVarOrError.takeError();
3076 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3077 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3078 }
3079
 // Recreate the GEPs at the current insertion point and record them for
 // the private variables stored in the task context structure.
3080 taskStructMgr.createGEPsToPrivateVars();
3081 for (auto [i, llvmPrivVar] :
3082 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3083 if (!llvmPrivVar) {
3084 assert(privateVarsInfo.llvmVars[i] &&
3085 "This is added in the loop above");
3086 continue;
3087 }
3088 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3089 }
3090
3091 // Find and map the addresses of each variable within the task context
3092 // structure
3093 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3094 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3095 privateVarsInfo.privatizers)) {
3096 // This was handled above.
3097 if (!privateDecl.readsFromMold())
3098 continue;
3099 // Fix broken pass-by-value case for Fortran character boxes
3100 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3101 llvmPrivateVar = builder.CreateLoad(
3102 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3103 }
3104 assert(llvmPrivateVar->getType() ==
3105 moduleTranslation.convertType(blockArg.getType()));
3106 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3107 }
3108
3109 auto continuationBlockOrError = convertOmpOpRegions(
3110 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
3111 if (failed(handleError(continuationBlockOrError, *taskOp)))
3112 return llvm::make_error<PreviouslyReportedError>();
3113
 // Cleanup code goes in front of the continuation block's terminator.
3114 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3115
3116 if (failed(cleanupPrivateVars(taskOp, builder, moduleTranslation,
3117 taskOp.getLoc(), privateVarsInfo)))
3118 return llvm::make_error<PreviouslyReportedError>();
3119
3120 // Free heap allocated task context structure at the end of the task.
3121 taskStructMgr.freeStructPtr();
3122
3123 return llvm::Error::success();
3124 };
3125
3126 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3127 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3128 // The directive to match here is OMPD_taskgroup because it is the taskgroup
3129 // which is canceled. This is handled here because it is the task's cleanup
3130 // block which should be branched to.
3131 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
3132 llvm::omp::Directive::OMPD_taskgroup);
3133
3134 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3135 if (failed(buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3136 taskOp.getDependIterated(),
3137 taskOp.getDependIteratedKinds(), builder,
3138 moduleTranslation, dependencies)))
3139 return failure();
3140
3141 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3142 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3143 moduleTranslation.getOpenMPBuilder()->createTask(
3144 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3145 moduleTranslation.lookupValue(taskOp.getFinal()),
3146 moduleTranslation.lookupValue(taskOp.getIfExpr()), dependencies, ad,
3147 taskOp.getMergeable(),
3148 moduleTranslation.lookupValue(taskOp.getEventHandle()),
3149 moduleTranslation.lookupValue(taskOp.getPriority()));
3150
3151 if (failed(handleError(afterIP, *taskOp)))
3152 return failure();
3153
3154 // Set the correct branch target for task cancellation
3155 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3156
3157 builder.restoreIP(*afterIP);
3158
 // A dependency array is only set when buildDependData heap-allocated one;
 // free it after the task creation code.
3159 if (dependencies.DepArray)
3160 builder.CreateFree(dependencies.DepArray);
3161
3162 return success();
3163}
3164
3165/// The correct entry point is convertOmpTaskloopContextOp. This gets called
3166/// whilst lowering the body of the taskloop context (i.e. the task function).
3167static LogicalResult
3168convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp,
3169 llvm::IRBuilderBase &builder,
3170 LLVM::ModuleTranslation &moduleTranslation) {
3171 mlir::Operation &opInst = *loopWrapperOp.getOperation();
3172 if (failed(checkImplementationStatus(opInst)))
3173 return failure();
3174
3175 // Recurse into the loop body.
3176 auto continuationBlockOrError = convertOmpOpRegions(
3177 loopWrapperOp.getRegion(), "omp.taskloop.wrapper.region", builder,
3178 moduleTranslation);
3179
3180 if (failed(handleError(continuationBlockOrError, opInst)))
3181 return failure();
3182
3183 builder.SetInsertPoint(continuationBlockOrError.get());
3184 return success();
3185}
3186
3187/// Look up the given value in the mapping, and if it's not there, translate its
3188/// defining operation at the current builder insertion point. Only pure,
3189/// regionless operations are supported because the same operation will later be
3190/// translated again when the taskloop body itself is lowered.
///
/// Unmapped operands of the defining operation are translated recursively.
/// Every mapping added by this function (operands and results) is removed
/// again before returning. Returns an error when the value is an unmapped
/// block argument, when the defining op has regions or is not pure, or when
/// converting the defining op fails.
3191static llvm::Expected<llvm::Value *>
 // NOTE(review): the line carrying the function name and its first parameter
 // (used below as `value`) is not visible in this excerpt.
3193 LLVM::ModuleTranslation &moduleTranslation,
3194 llvm::IRBuilderBase &builder) {
 // Already translated: nothing to emit.
3195 if (llvm::Value *mapped = moduleTranslation.lookupValue(value))
3196 return mapped;
3197
3198 Operation *defOp = value.getDefiningOp();
3199 if (!defOp)
3200 return llvm::make_error<llvm::StringError>(
3201 "value is a block argument and is not mapped",
3202 llvm::inconvertibleErrorCode());
3203 if (defOp->getNumRegions() != 0 || !isPure(defOp))
3204 return llvm::make_error<llvm::StringError>(
3205 "unsupported op defining taskloop loop bound",
3206 llvm::inconvertibleErrorCode());
3207
 // Track every temporary mapping so it can be dropped before returning.
3208 SmallVector<Value> mappingsToRemove;
3209 mappingsToRemove.reserve(defOp->getNumOperands() + defOp->getNumResults());
3210 for (Value operand : defOp->getOperands()) {
3211 if (moduleTranslation.lookupValue(operand))
3212 continue;
3213
 // The recursive call has already removed its own mappings, so the
 // operand is mapped again here for the conversion of defOp.
3214 llvm::Expected<llvm::Value *> operandOrError =
3215 lookupOrTranslatePureValue(operand, moduleTranslation, builder);
3216 if (!operandOrError)
3217 return operandOrError.takeError();
3218 moduleTranslation.mapValue(operand, *operandOrError);
3219 mappingsToRemove.push_back(operand);
3220 }
3221
3222 if (failed(moduleTranslation.convertOperation(*defOp, builder)))
3223 return llvm::make_error<llvm::StringError>(
3224 "failed to convert op defining taskloop loop bound",
3225 llvm::inconvertibleErrorCode());
3226
 // Fetch the result before its mapping is removed below.
3227 llvm::Value *result = moduleTranslation.lookupValue(value);
3228 assert(result && "expected conversion of loop bound op to produce a value");
3229
3230 for (Value resultValue : defOp->getResults()) {
3231 if (moduleTranslation.lookupValue(resultValue))
3232 mappingsToRemove.push_back(resultValue);
3233 }
3234 for (Value mappedValue : mappingsToRemove)
3235 moduleTranslation.forgetMapping(mappedValue);
3236
3237 return result;
3238}
3239
/// Computes the lower bound, upper bound and step that the caller hands to
/// createTaskloop for `loopOp`, emitting any required IR at the builder's
/// current insertion point. Each MLIR bound is obtained through
/// lookupOrTranslatePureValue.
///
/// Without collapse (getCollapseNumLoops() <= 1) the first loop's lower bound,
/// upper bound and step are returned unchanged. With collapse, `lbVal` and
/// `stepVal` are the constant 1 and `ubVal` is the product of the per-loop trip
/// counts computed below.
///
/// Returns an error if any bound cannot be looked up or translated. On success
/// all three output references are non-null (asserted).
3240static llvm::Error
3241computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
3242 LLVM::ModuleTranslation &moduleTranslation,
3243 llvm::Value *&lbVal, llvm::Value *&ubVal,
3244 llvm::Value *&stepVal) {
3245 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
3246 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
3247 Operation::operand_range steps = loopOp.getLoopSteps();
3248
 // The first lower bound is translated up front because its LLVM type supplies
 // the bit width of every integer constant created below.
3249 llvm::Expected<llvm::Value *> firstLbOrErr =
3250 lookupOrTranslatePureValue(lowerBounds[0], moduleTranslation, builder);
3251 if (!firstLbOrErr)
3252 return firstLbOrErr.takeError();
3253
3254 llvm::Type *boundType = (*firstLbOrErr)->getType();
 // Seed for the running product of trip counts; overwritten in the
 // non-collapsed branch.
3255 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3256 if (loopOp.getCollapseNumLoops() > 1) {
3257 // In cases where Collapse is used with Taskloop, the upper bound of the
3258 // iteration space needs to be recalculated to cater for the collapsed loop.
3259 // The Collapsed Loop UpperBound is the product of all collapsed
3260 // loop's tripcount.
3261 // The LowerBound for collapsed loops is always 1. When the loops are
3262 // collapsed, it will reset the bounds and introduce processing to ensure
3263 // the indices are presented as expected. As this happens after creating
3264 // Taskloop, these bounds need predicting. Example:
3265 // !$omp taskloop collapse(2)
3266 // do i = 1, 10
3267 // do j = 1, 5
3268 // ..
3269 // end do
3270 // end do
3271 // This loop above has a total of 50 iterations, so the lb will be 1, and
3272 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
3273 // that i and j are properly presented when used in the loop.
3274 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
 // NOTE(review): the declarations of lbOrErr, ubOrErr and stepOrErr are not
 // present in this listing; only their initialisers are visible. The first
 // iteration reuses the lower bound translated above.
3276 i == 0 ? std::move(firstLbOrErr)
3277 : lookupOrTranslatePureValue(lowerBounds[i], moduleTranslation,
3278 builder);
3279 if (!lbOrErr)
3280 return lbOrErr.takeError();
3282 upperBounds[i], moduleTranslation, builder);
3283 if (!ubOrErr)
3284 return ubOrErr.takeError();
3286 lookupOrTranslatePureValue(steps[i], moduleTranslation, builder);
3287 if (!stepOrErr)
3288 return stepOrErr.takeError();
3289
3290 llvm::Value *loopLb = *lbOrErr;
3291 llvm::Value *loopUb = *ubOrErr;
3292 llvm::Value *loopStep = *stepOrErr;
3293 // In some cases, such as where the ub is less than the lb so the loop
3294 // steps down, the calculation for the loopTripCount is swapped. To ensure
3295 // the correct value is found, calculate both UB - LB and LB - UB then
3296 // select which value to use depending on how the loop has been
3297 // configured.
 // Concretely: ubMinusLb = ub - (lb - 1) and lbMinusUb = lb - (ub - 1); the
 // signed comparison lb < ub picks the first, otherwise the second.
3298 llvm::Value *loopLbMinusOne = builder.CreateSub(
3299 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3300 llvm::Value *loopUbMinusOne = builder.CreateSub(
3301 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3302 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3303 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3304 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3305 llvm::Value *loopTripCount =
3306 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3307 loopTripCount = builder.CreateBinaryIntrinsic(
3308 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3309 // For loops that have a step value not equal to 1, we need to adjust the
3310 // trip count to ensure the correct number of iterations for the loop is
3311 // captured.
 // The adjusted count is |count / step|, plus one when count % step is
 // non-zero (i.e. the division is rounded up).
3312 llvm::Value *loopTripCountDivStep =
3313 builder.CreateSDiv(loopTripCount, loopStep);
3314 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3315 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3316 llvm::Value *loopTripCountRem =
3317 builder.CreateSRem(loopTripCount, loopStep);
3318 loopTripCountRem = builder.CreateBinaryIntrinsic(
3319 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3320 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3321 loopTripCountRem,
3322 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3323 0));
3324 loopTripCount =
3325 builder.CreateAdd(loopTripCountDivStep,
3326 builder.CreateZExtOrTrunc(
3327 needsRoundUp, loopTripCountDivStep->getType()));
 // Accumulate the product of the trip counts of all collapsed loops.
3328 ubVal = builder.CreateMul(ubVal, loopTripCount);
3329 }
3330 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3331 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3332 } else {
 // Single loop: pass the original bounds straight through.
 // NOTE(review): the declarations of ubOrErr and stepOrErr are not present in
 // this listing; only their initialisers are visible.
3334 lookupOrTranslatePureValue(upperBounds[0], moduleTranslation, builder);
3335 if (!ubOrErr)
3336 return ubOrErr.takeError();
3338 lookupOrTranslatePureValue(steps[0], moduleTranslation, builder);
3339 if (!stepOrErr)
3340 return stepOrErr.takeError();
3341 lbVal = *firstLbOrErr;
3342 ubVal = *ubOrErr;
3343 stepVal = *stepOrErr;
3344 }
3345
3346 assert(lbVal != nullptr && "Expected value for lbVal");
3347 assert(ubVal != nullptr && "Expected value for ubVal");
3348 assert(stepVal != nullptr && "Expected value for stepVal");
3349 return llvm::Error::success();
3350}
3351
3352// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
//
// Outline of what the visible code does:
//  1. Splits off "omp.private.init" and "omp.private.copy" blocks in front of
//     a new "omp.taskloop.wrapper.start" block, generates the task context
//     structure, initialises the privatized variables whose privatizer reads
//     from its mold, and runs the firstprivate copies.
//  2. Computes the loop bounds with computeTaskloopBounds.
//  3. Calls OpenMPIRBuilder::createTaskloop with a body callback (lowers the
//     region of `contextOp`), a callback returning the current
//     CanonicalLoopInfo, and, when a context structure exists, a task
//     duplication callback that builds a fresh context structure.
// Returns failure if the implementation-status check or any of these steps
// reports an error; on success the builder is left at the insertion point
// returned by createTaskloop.
3353static LogicalResult
3354convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
3355 llvm::IRBuilderBase &builder,
3356 LLVM::ModuleTranslation &moduleTranslation) {
3357 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3358 mlir::Operation &opInst = *contextOp.getOperation();
3359 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3360 if (failed(checkImplementationStatus(opInst)))
3361 return failure();
3362
3363 // It stores the pointer of allocated firstprivate copies,
3364 // which can be used later for freeing the allocated space.
 // NOTE(review): in the lines visible here this vector is resized and written
 // but never read — confirm whether it is still needed.
3365 SmallVector<llvm::Value *> llvmFirstPrivateVars;
3366 PrivateVarsInfo privateVarsInfo(contextOp);
3367 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3368 privateVarsInfo.privatizers};
3369
 // NOTE(review): `deallocBlocks` is used below but its declaration is not
 // present in this listing.
3371 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3372 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3373
 // Create the block that will receive the createTaskloop call, branch to it
 // from the current (unterminated) block, and move the insertion point in
 // front of that branch. The private-variable blocks are split off from there.
3374 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3375 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3376 builder.getContext(), "omp.taskloop.wrapper.start",
3377 /*Parent=*/builder.GetInsertBlock()->getParent());
3378 llvm::Instruction *branchToTaskloopStartBlock =
3379 builder.CreateBr(taskloopStartBlock);
3380 builder.SetInsertPoint(branchToTaskloopStartBlock);
3381
3382 llvm::BasicBlock *copyBlock =
3383 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
3384 llvm::BasicBlock *initBlock =
3385 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
3386
 // NOTE(review): the first line of this statement is not present in this
 // listing. Judging by the commented twin inside bodyCB it presumably saves
 // the alloca insertion point on the ModuleTranslation stack — confirm.
3388 moduleTranslation, allocaIP, deallocBlocks);
3389
3390 // Allocate and initialize private variables
3391 builder.SetInsertPoint(initBlock->getTerminator());
3392
3393 // TODO: don't allocate if the loop has zero iterations.
3394 taskStructMgr.generateTaskContextStruct();
3395 taskStructMgr.createGEPsToPrivateVars();
3396
3397 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
3398
3399 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3400 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
3401 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3402 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3403 // To be handled inside the taskloop.
3404 if (!privDecl.readsFromMold())
3405 continue;
3406 assert(llvmPrivateVarAlloc &&
3407 "reads from mold so shouldn't have been skipped");
3408
3409 llvm::Expected<llvm::Value *> privateVarOrErr =
3410 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3411 blockArg, llvmPrivateVarAlloc, initBlock);
3412 if (!privateVarOrErr)
3413 return handleError(privateVarOrErr, *contextOp.getOperation());
3414
3415 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3416
 // The guard restores the insertion point at the end of this iteration.
3417 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3418 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3419
 // When the init region yielded a value other than the allocation itself and
 // the block argument is not a pointer, store the yielded value into the
 // allocation and reload it so the value matches the block argument's type.
3420 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3421 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3422 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3423 // Load it so we have the value pointed to by the GEP
3424 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3425 llvmPrivateVarAlloc);
3426 }
3427 assert(llvmPrivateVarAlloc->getType() ==
3428 moduleTranslation.convertType(blockArg.getType()));
3429 }
3430
3431 // firstprivate copy region
3432 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3433 if (failed(copyFirstPrivateVars(
3434 contextOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3435 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3436 contextOp.getPrivateNeedsBarrier())))
3437 return llvm::failure();
3438
3439 // Set up insertion point for call to createTaskloop()
3440 builder.SetInsertPoint(taskloopStartBlock);
3441
 // Bounds are computed here, before createTaskloop, and passed to it below.
3442 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3443 llvm::Value *lbVal = nullptr;
3444 llvm::Value *ubVal = nullptr;
3445 llvm::Value *stepVal = nullptr;
3446 if (llvm::Error err = computeTaskloopBounds(
3447 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3448 return handleError(std::move(err), opInst);
3449
 // Body callback passed to createTaskloop: allocates and initialises the
 // privatized variables that do not read from their mold, maps every private
 // block argument, lowers the region of `contextOp`, and finally emits the
 // cleanup of the private variables and of the task context structure.
3450 auto bodyCB =
3451 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3452 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3453 // Save the alloca insertion point on ModuleTranslation stack for use in
3454 // nested regions.
 // NOTE(review): the first line of the following statement is not present in
 // this listing.
3456 moduleTranslation, allocaIP, deallocBlocks);
3457
3458 // translate the body of the taskloop:
3459 builder.restoreIP(codegenIP);
3460
3461 llvm::BasicBlock *privInitBlock = nullptr;
3462 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3463 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3464 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3465 privateVarsInfo.mlirVars))) {
3466 auto [blockArg, privDecl, mlirPrivVar] = zip;
3467 // This is handled before the task executes
3468 if (privDecl.readsFromMold())
3469 continue;
3470
 // Allocate the private copy before the terminator of the alloca block; the
 // guard puts the insertion point back afterwards.
3471 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3472 llvm::Type *llvmAllocType =
3473 moduleTranslation.convertType(privDecl.getType());
3474 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3475 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3476 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3477
3478 llvm::Expected<llvm::Value *> privateVarOrError =
3479 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3480 blockArg, llvmPrivateVar, privInitBlock);
3481 if (!privateVarOrError)
3482 return privateVarOrError.takeError();
3483 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3484 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3485 }
3486
 // Fill the remaining llvmVars slots from the context-structure GEPs. A null
 // GEP entry means the slot was already set by the loop above.
3487 taskStructMgr.createGEPsToPrivateVars();
3488 for (auto [i, llvmPrivVar] :
3489 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3490 if (!llvmPrivVar) {
3491 assert(privateVarsInfo.llvmVars[i] &&
3492 "This is added in the loop above");
3493 continue;
3494 }
3495 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3496 }
3497
3498 // Find and map the addresses of each variable within the taskloop context
3499 // structure
3500 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3501 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3502 privateVarsInfo.privatizers)) {
3503 // This was handled above.
3504 if (!privateDecl.readsFromMold())
3505 continue;
3506 // Fix broken pass-by-value case for Fortran character boxes
3507 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3508 llvmPrivateVar = builder.CreateLoad(
3509 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3510 }
3511 assert(llvmPrivateVar->getType() ==
3512 moduleTranslation.convertType(blockArg.getType()));
3513 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3514 }
3515
3516 // Lower the contents of the taskloop context region: this is the body of
3517 // the generated task, not the loop.
3518 auto continuationBlockOrError = convertOmpOpRegions(
3519 contextOp.getRegion(), "omp.taskloop.context.region", builder,
3520 moduleTranslation);
3521
3522 if (failed(handleError(continuationBlockOrError, opInst)))
3523 return llvm::make_error<PreviouslyReportedError>();
3524
3525 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3526
3527 // This is freeing the private variables as mapped inside of the task: these
3528 // will be per-task private copies possibly after task duplication. This is
3529 // handled transparently by how these are passed to the structure passed
3530 // into the outlined function. When the task is duplicated, that structure
3531 // is duplicated too.
3532 if (failed(cleanupPrivateVars(contextOp, builder, moduleTranslation,
3533 contextOp.getLoc(), privateVarsInfo)))
3534 return llvm::make_error<PreviouslyReportedError>();
3535 // Similarly, the task context structure freed inside the task is the
3536 // per-task copy after task duplication.
3537 taskStructMgr.freeStructPtr();
3538
3539 return llvm::Error::success();
3540 };
3541
3542 // Taskloop divides into an appropriate number of tasks by repeatedly
3543 // duplicating the original task. Each time this is done, the task context
3544 // structure must be duplicated too.
 // NOTE(review): the line with this lambda's return type and opening brace is
 // not present in this listing.
3545 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3546 llvm::Value *destPtr, llvm::Value *srcPtr)
3548 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3549 builder.restoreIP(codegenIP);
3550
 // Load the source context-structure pointer out of srcPtr, using a pointer
 // type in srcPtr's own address space.
3551 llvm::Type *ptrTy =
3552 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3553 llvm::Value *src =
3554 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
3555
 // Generate a new context structure for the duplicate and store its pointer
 // through destPtr.
3556 TaskContextStructManager &srcStructMgr = taskStructMgr;
3557 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3558 privateVarsInfo.privatizers);
3559 destStructMgr.generateTaskContextStruct();
3560 llvm::Value *dest = destStructMgr.getStructPtr();
3561 dest->setName("omp.taskloop.context.dest");
3562 builder.CreateStore(dest, destPtr);
3563
 // NOTE(review): `srcGEPs` and `destGEPs` are used below, but the lines
 // declaring them are not present in this listing; presumably they hold the
 // results of the two calls that follow — confirm.
3565 srcStructMgr.createGEPsToPrivateVars(src);
3567 destStructMgr.createGEPsToPrivateVars(dest);
3568
3569 // Inline init regions.
3570 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3571 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
3572 privateVarsInfo.blockArgs, destGEPs)) {
3573 // To be handled inside task body.
3574 if (!privDecl.readsFromMold())
3575 continue;
3576 assert(llvmPrivateVarAlloc &&
3577 "reads from mold so shouldn't have been skipped");
3578
3579 llvm::Expected<llvm::Value *> privateVarOrErr =
3580 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3581 llvmPrivateVarAlloc, builder.GetInsertBlock());
3582 if (!privateVarOrErr)
3583 return privateVarOrErr.takeError();
3584
3586
3587 // TODO: this is a bit of a hack for Fortran character boxes.
3588 // Character boxes are passed by value into the init region and then the
3589 // initialized character box is yielded by value. Here we need to store
3590 // the yielded value into the private allocation, and load the private
3591 // allocation to match the type expected by region block arguments.
3592 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3593 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3594 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3595 // Load it so we have the value pointed to by the GEP
3596 llvmPrivateVarAlloc = builder.CreateLoad(
3597 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3598 }
3599 assert(llvmPrivateVarAlloc->getType() ==
3600 moduleTranslation.convertType(blockArg.getType()));
3601
3602 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
3603 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
3604 // through a stack allocated structure.
3605 }
3606
 // Run the firstprivate copy regions from the source structure's fields into
 // the destination structure's fields.
3607 if (failed(copyFirstPrivateVars(contextOp.getOperation(), builder,
3608 moduleTranslation, srcGEPs, destGEPs,
3609 privateVarsInfo.privatizers,
3610 contextOp.getPrivateNeedsBarrier())))
3611 return llvm::make_error<PreviouslyReportedError>();
3612
3613 return builder.saveIP();
3614 };
3615
 // Callback handed to createTaskloop; when invoked it returns whatever
 // findCurrentLoopInfo reports at that moment.
3616 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
3617 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3618 return loopInfo;
3619 };
3620
 // `grainsize` carries the value of either the grainsize or the num_tasks
 // clause; `sched` records which one (0 when neither is present).
3621 llvm::Value *ifCond = nullptr;
3622 llvm::Value *grainsize = nullptr;
3623 int sched = 0; // default
3624 mlir::Value grainsizeVal = contextOp.getGrainsize();
3625 mlir::Value numTasksVal = contextOp.getNumTasks();
3626 if (Value ifVar = contextOp.getIfExpr())
3627 ifCond = moduleTranslation.lookupValue(ifVar);
3628 if (grainsizeVal) {
3629 grainsize = moduleTranslation.lookupValue(grainsizeVal);
3630 sched = 1; // grainsize
3631 } else if (numTasksVal) {
3632 grainsize = moduleTranslation.lookupValue(numTasksVal);
3633 sched = 2; // num_tasks
3634 }
3635
 // The duplication callback is only installed when a task context structure
 // was generated.
3636 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
3637 if (taskStructMgr.getStructPtr())
3638 taskDupOrNull = taskDupCB;
3639
3640 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3641 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3642 // The directive to match here is OMPD_taskgroup because it is the
3643 // taskgroup which is canceled. This is handled here because it is the
3644 // task's cleanup block which should be branched to. It doesn't depend upon
3645 // nogroup because even in that case the taskloop might still be inside an
3646 // explicit taskgroup.
3647 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, contextOp,
3648 llvm::omp::Directive::OMPD_taskgroup);
3649
3650 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3651 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3652 moduleTranslation.getOpenMPBuilder()->createTaskloop(
3653 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3654 stepVal, contextOp.getUntied(), ifCond, grainsize,
3655 contextOp.getNogroup(), sched,
3656 moduleTranslation.lookupValue(contextOp.getFinal()),
3657 contextOp.getMergeable(),
3658 moduleTranslation.lookupValue(contextOp.getPriority()),
3659 loopOp.getCollapseNumLoops(), taskDupOrNull,
3660 taskStructMgr.getStructPtr());
3661
3662 if (failed(handleError(afterIP, opInst)))
3663 return failure();
3664
3665 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3666
 // Continue emitting code after the generated taskloop.
3667 builder.restoreIP(*afterIP);
3668 return success();
3669}
3670
3671/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
///
/// The region of `tgOp` is lowered by a body callback passed to
/// OpenMPIRBuilder::createTaskgroup. Returns failure if the
/// implementation-status check fails or createTaskgroup reports an error;
/// on success the builder is left at the insertion point createTaskgroup
/// returned.
3672static LogicalResult
3673convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
3674 LLVM::ModuleTranslation &moduleTranslation) {
3675 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3676 if (failed(checkImplementationStatus(*tgOp)))
3677 return failure();
3678
 // Body callback: move the builder to the code generation point supplied by
 // OpenMPIRBuilder and lower the taskgroup region there. Only the error state
 // of the region conversion is returned.
3679 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3680 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
3681 builder.restoreIP(codegenIP);
3682 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
3683 builder, moduleTranslation)
3684 .takeError();
3685 };
3686
 // NOTE(review): `deallocBlocks` is used below but its declaration is not
 // present in this listing.
3688 InsertPointTy allocaIP =
3689 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3690 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3691 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3692 moduleTranslation.getOpenMPBuilder()->createTaskgroup(
3693 ompLoc, allocaIP, deallocBlocks, bodyCB);
3694
3695 if (failed(handleError(afterIP, *tgOp)))
3696 return failure();
3697
 // Continue emitting code after the generated taskgroup.
3698 builder.restoreIP(*afterIP);
3699 return success();
3700}
3701
3702static LogicalResult
3703convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
3704 LLVM::ModuleTranslation &moduleTranslation) {
3705 if (failed(checkImplementationStatus(*twOp)))
3706 return failure();
3707
3708 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
3709 return success();
3710}
3711
3712/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
///
/// Outline of what the visible code does: reads the schedule and chunk
/// configuration (including a dist_schedule chunk when the direct parent is
/// omp.distribute), sets up private and reduction variables, lowers the
/// "omp.wsloop.region" region, applies the worksharing-loop transformation to
/// the resulting CanonicalLoopInfo, handles linear-clause variables, and then
/// emits the reductions and the cleanup of private variables.
/// Returns failure if the implementation-status check or any step fails.
3713static LogicalResult
3714convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
3715 LLVM::ModuleTranslation &moduleTranslation) {
3716 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3717 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3718 if (failed(checkImplementationStatus(opInst)))
3719 return failure();
3720
3721 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3722 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
3723 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3724
3725 // Static is the default.
3726 auto schedule =
3727 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3728
3729 // Find the loop configuration.
 // The LLVM type of the first loop step is used as the induction-variable
 // type; chunk sizes are sign-extended or truncated to it.
3730 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
3731 llvm::Type *ivType = step->getType();
3732 llvm::Value *chunk = nullptr;
3733 if (wsloopOp.getScheduleChunk()) {
3734 llvm::Value *chunkVar =
3735 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
3736 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3737 }
3738
 // When the direct parent is omp.distribute, also pick up its dist_schedule
 // settings.
3739 omp::DistributeOp distributeOp = nullptr;
3740 llvm::Value *distScheduleChunk = nullptr;
3741 bool hasDistSchedule = false;
3742 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
3743 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
3744 hasDistSchedule = distributeOp.getDistScheduleStatic();
3745 if (distributeOp.getDistScheduleChunkSize()) {
3746 llvm::Value *chunkVar = moduleTranslation.lookupValue(
3747 distributeOp.getDistScheduleChunkSize());
3748 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3749 }
3750 }
3751
3752 PrivateVarsInfo privateVarsInfo(wsloopOp);
3753
 // NOTE(review): the declaration of `reductionDecls` is not present in this
 // listing.
3755 collectReductionDecls(wsloopOp, reductionDecls);
3756
3757 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3758 findAllocInsertPoints(builder, moduleTranslation);
3759
3760 SmallVector<llvm::Value *> privateReductionVariables(
3761 wsloopOp.getNumReductionVars());
3762
 // NOTE(review): the first line of the statement that defines `afterAllocas`
 // is not present in this listing; only its argument list is visible.
3764 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3765 if (handleError(afterAllocas, opInst).failed())
3766 return failure();
3767
3768 DenseMap<Value, llvm::Value *> reductionVariableMap;
3769
3770 MutableArrayRef<BlockArgument> reductionArgs =
3771 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3772
3773 SmallVector<DeferredStore> deferredStores;
3774
3775 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3776 moduleTranslation, allocaIP, reductionDecls,
3777 privateReductionVariables, reductionVariableMap,
3778 deferredStores, isByRef)))
3779 return failure();
3780
3781 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3782 opInst)
3783 .failed())
3784 return failure();
3785
3786 if (failed(copyFirstPrivateVars(
3787 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3788 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3789 wsloopOp.getPrivateNeedsBarrier())))
3790 return failure();
3791
3792 assert(afterAllocas.get()->getSinglePredecessor());
3793 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3794 moduleTranslation,
3795 afterAllocas.get()->getSinglePredecessor(),
3796 reductionDecls, privateReductionVariables,
3797 reductionVariableMap, isByRef, deferredStores)))
3798 return failure();
3799
3800 // TODO: Handle doacross loops when the ordered clause has a parameter.
3801 bool isOrdered = wsloopOp.getOrdered().has_value();
3802 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3803 bool isSimd = wsloopOp.getScheduleSimd();
3804 bool loopNeedsBarrier = !wsloopOp.getNowait();
3805
3806 // The only legal way for the direct parent to be omp.distribute is that this
3807 // represents 'distribute parallel do'. Otherwise, this is a regular
3808 // worksharing loop.
3809 llvm::omp::WorksharingLoopType workshareLoopType =
3810 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3811 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3812 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3813
3814 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3815 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3816 llvm::omp::Directive::OMPD_for);
3817
3818 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3819
3820 // Initialize linear variables and linear step
3821 LinearClauseProcessor linearClauseProcessor;
3822
3823 if (!wsloopOp.getLinearVars().empty()) {
3824 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3825 for (mlir::Attribute linearVarType : linearVarTypes)
3826 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3827
3828 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3829 linearClauseProcessor.createLinearVar(
3830 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3831 idx);
3832 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3833 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3834 }
3835
 // Remember the block we are emitting into before the region is lowered; its
 // single successor is used further down for the linear-variable rewrite.
3836 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
 // NOTE(review): the first line of the statement that defines `regionBlock`
 // is not present in this listing; only its argument list is visible.
3838 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3839
3840 if (failed(handleError(regionBlock, opInst)))
3841 return failure();
3842
3843 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3844
3845 // Emit Initialization and Update IR for linear variables
3846 if (!wsloopOp.getLinearVars().empty()) {
3847 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3848 loopInfo->getPreheader());
3849 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3850 moduleTranslation.getOpenMPBuilder()->createBarrier(
3851 builder.saveIP(), llvm::omp::OMPD_barrier);
3852 if (failed(handleError(afterBarrierIP, *loopOp)))
3853 return failure();
3854 builder.restoreIP(*afterBarrierIP);
3855 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3856 loopInfo->getIndVar());
3857 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3858 }
3859
3860 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3861
3862 // Check if we can generate no-loop kernel
3863 bool noLoopMode = false;
3864 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3865 if (targetOp) {
3866 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3867 // We need this check because, without it, noLoopMode would be set to true
3868 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3869 // that loop is not the top-level SPMD one.
3870 if (loopOp == targetCapturedOp) {
3871 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3872 omp::TargetExecMode::no_loop)
3873 noLoopMode = true;
3874 }
3875 }
3876
 // Apply the worksharing-loop transformation to the canonical loop found
 // above, using the schedule configuration gathered earlier.
3877 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3878 ompBuilder->applyWorkshareLoop(
3879 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3880 convertToScheduleKind(schedule), chunk, isSimd,
3881 scheduleMod == omp::ScheduleModifier::monotonic,
3882 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3883 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3884
3885 if (failed(handleError(wsloopIP, opInst)))
3886 return failure();
3887
3888 // Emit finalization and in-place rewrites for linear vars.
3889 if (!wsloopOp.getLinearVars().empty()) {
 // The builder's insertion point is saved here and restored once the linear
 // variables have been finalized and rewritten.
3890 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3891 assert(loopInfo->getLastIter() &&
3892 "`lastiter` in CanonicalLoopInfo is nullptr");
3893 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3894 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3895 loopInfo->getLastIter());
3896 if (failed(handleError(afterBarrierIP, *loopOp)))
3897 return failure();
3898 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3899 linearClauseProcessor.rewriteInPlace(
3900 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3901 "omp.loop_nest.region", index);
3902
3903 builder.restoreIP(oldIP);
3904 }
3905
3906 // Set the correct branch target for task cancellation
3907 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3908
3909 // Process the reductions if required.
3910 if (failed(createReductionsAndCleanup(
3911 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3912 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3913 /*isTeamsReduction=*/false)))
3914 return failure();
3915
 // The result of cleaning up the private variables is the overall result.
3916 return cleanupPrivateVars(wsloopOp, builder, moduleTranslation,
3917 wsloopOp.getLoc(), privateVarsInfo);
3918}
3919
3920/// Converts the OpenMP parallel operation to LLVM IR.
3921static LogicalResult
3922convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3923 LLVM::ModuleTranslation &moduleTranslation) {
3924 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3925 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3926 assert(isByRef.size() == opInst.getNumReductionVars());
3927 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3928 bool isCancellable = constructIsCancellable(opInst);
3929
3930 if (failed(checkImplementationStatus(*opInst)))
3931 return failure();
3932
3933 PrivateVarsInfo privateVarsInfo(opInst);
3934
3935 // Collect reduction declarations
3937 collectReductionDecls(opInst, reductionDecls);
3938 SmallVector<llvm::Value *> privateReductionVariables(
3939 opInst.getNumReductionVars());
3940 SmallVector<DeferredStore> deferredStores;
3941
3942 auto bodyGenCB =
3943 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3944 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3946 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3947 if (handleError(afterAllocas, *opInst).failed())
3948 return llvm::make_error<PreviouslyReportedError>();
3949
3950 // Allocate reduction vars
3951 DenseMap<Value, llvm::Value *> reductionVariableMap;
3952
3953 MutableArrayRef<BlockArgument> reductionArgs =
3954 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3955
3956 allocaIP =
3957 InsertPointTy(allocaIP.getBlock(),
3958 allocaIP.getBlock()->getTerminator()->getIterator());
3959
3960 if (failed(allocReductionVars(
3961 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3962 reductionDecls, privateReductionVariables, reductionVariableMap,
3963 deferredStores, isByRef)))
3964 return llvm::make_error<PreviouslyReportedError>();
3965
3966 assert(afterAllocas.get()->getSinglePredecessor());
3967 builder.restoreIP(codeGenIP);
3968
3969 if (handleError(
3970 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3971 *opInst)
3972 .failed())
3973 return llvm::make_error<PreviouslyReportedError>();
3974
3975 if (failed(copyFirstPrivateVars(
3976 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3977 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3978 opInst.getPrivateNeedsBarrier())))
3979 return llvm::make_error<PreviouslyReportedError>();
3980
3981 if (failed(
3982 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3983 afterAllocas.get()->getSinglePredecessor(),
3984 reductionDecls, privateReductionVariables,
3985 reductionVariableMap, isByRef, deferredStores)))
3986 return llvm::make_error<PreviouslyReportedError>();
3987
3988 // Save the alloca insertion point on ModuleTranslation stack for use in
3989 // nested regions.
3991 moduleTranslation, allocaIP, deallocBlocks);
3992
3993 // ParallelOp has only one region associated with it.
3995 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3996 if (!regionBlock)
3997 return regionBlock.takeError();
3998
3999 // Process the reductions if required.
4000 if (opInst.getNumReductionVars() > 0) {
4001 // Collect reduction info
4003 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
4005 owningReductionGenRefDataPtrGens;
4007 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
4008 owningReductionGens, owningAtomicReductionGens,
4009 owningReductionGenRefDataPtrGens,
4010 privateReductionVariables, reductionInfos, isByRef);
4011
4012 // Move to region cont block
4013 builder.SetInsertPoint((*regionBlock)->getTerminator());
4014
4015 // Generate reductions from info
4016 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
4017 builder.SetInsertPoint(tempTerminator);
4018
4019 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
4020 ompBuilder->createReductions(
4021 builder.saveIP(), allocaIP, reductionInfos, isByRef,
4022 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
4023 if (!contInsertPoint)
4024 return contInsertPoint.takeError();
4025
4026 if (!contInsertPoint->getBlock())
4027 return llvm::make_error<PreviouslyReportedError>();
4028
4029 tempTerminator->eraseFromParent();
4030 builder.restoreIP(*contInsertPoint);
4031 }
4032
4033 return llvm::Error::success();
4034 };
4035
4036 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
4037 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
4038 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
4039 // bodyGenCB.
4040 replVal = &val;
4041 return codeGenIP;
4042 };
4043
4044 // TODO: Perform finalization actions for variables. This has to be
4045 // called for variables which have destructors/finalizers.
4046 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4047 InsertPointTy oldIP = builder.saveIP();
4048 builder.restoreIP(codeGenIP);
4049
4050 // if the reduction has a cleanup region, inline it here to finalize the
4051 // reduction variables
4052 SmallVector<Region *> reductionCleanupRegions;
4053 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4054 [](omp::DeclareReductionOp reductionDecl) {
4055 return &reductionDecl.getCleanupRegion();
4056 });
4057 if (failed(inlineOmpRegionCleanup(
4058 reductionCleanupRegions, privateReductionVariables,
4059 moduleTranslation, builder, "omp.reduction.cleanup")))
4060 return llvm::createStringError(
4061 "failed to inline `cleanup` region of `omp.declare_reduction`");
4062
4063 if (failed(cleanupPrivateVars(opInst, builder, moduleTranslation,
4064 opInst.getLoc(), privateVarsInfo)))
4065 return llvm::make_error<PreviouslyReportedError>();
4066
4067 // If we could be performing cancellation, add the cancellation barrier on
4068 // the way out of the outlined region.
4069 if (isCancellable) {
4070 auto IPOrErr = ompBuilder->createBarrier(
4071 llvm::OpenMPIRBuilder::LocationDescription(builder),
4072 llvm::omp::Directive::OMPD_unknown,
4073 /* ForceSimpleCall */ false,
4074 /* CheckCancelFlag */ false);
4075 if (!IPOrErr)
4076 return IPOrErr.takeError();
4077 }
4078
4079 builder.restoreIP(oldIP);
4080 return llvm::Error::success();
4081 };
4082
4083 llvm::Value *ifCond = nullptr;
4084 if (auto ifVar = opInst.getIfExpr())
4085 ifCond = moduleTranslation.lookupValue(ifVar);
4086 llvm::Value *numThreads = nullptr;
4087 if (!opInst.getNumThreadsVars().empty())
4088 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
4089 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4090 if (auto bind = opInst.getProcBindKind())
4091 pbKind = getProcBindKind(*bind);
4092
4094 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4095 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
4096 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4097
4098 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4099 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4100 privCB, finiCB, ifCond, numThreads, pbKind,
4101 isCancellable);
4102
4103 if (failed(handleError(afterIP, *opInst)))
4104 return failure();
4105
4106 builder.restoreIP(*afterIP);
4107 return success();
4108}
4109
4110/// Convert Order attribute to llvm::omp::OrderKind.
4111static llvm::omp::OrderKind
4112convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
4113 if (!o)
4114 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4115 switch (*o) {
4116 case omp::ClauseOrderKind::Concurrent:
4117 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4118 }
4119 llvm_unreachable("Unknown ClauseOrderKind kind");
4120}
4121
4122/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
4123static LogicalResult
4124convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
4125 LLVM::ModuleTranslation &moduleTranslation) {
4126 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4127 auto simdOp = cast<omp::SimdOp>(opInst);
4128
4129 if (failed(checkImplementationStatus(opInst)))
4130 return failure();
4131
4132 PrivateVarsInfo privateVarsInfo(simdOp);
4133
4134 MutableArrayRef<BlockArgument> reductionArgs =
4135 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4136 DenseMap<Value, llvm::Value *> reductionVariableMap;
4137 SmallVector<llvm::Value *> privateReductionVariables(
4138 simdOp.getNumReductionVars());
4139 SmallVector<DeferredStore> deferredStores;
4141 collectReductionDecls(simdOp, reductionDecls);
4142 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
4143 assert(isByRef.size() == simdOp.getNumReductionVars());
4144
4145 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4146 findAllocInsertPoints(builder, moduleTranslation);
4147
4149 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
4150 if (handleError(afterAllocas, opInst).failed())
4151 return failure();
4152
4153 // Initialize linear variables and linear step
4154 LinearClauseProcessor linearClauseProcessor;
4155
4156 if (!simdOp.getLinearVars().empty()) {
4157 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4158 for (mlir::Attribute linearVarType : linearVarTypes)
4159 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4160 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4161 bool isImplicit = false;
4162 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4163 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
4164 // If the linear variable is implicit, reuse the already
4165 // existing llvm::Value
4166 if (linearVar == mlirPrivVar) {
4167 isImplicit = true;
4168 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4169 llvmPrivateVar, idx);
4170 break;
4171 }
4172 }
4173
4174 if (!isImplicit)
4175 linearClauseProcessor.createLinearVar(
4176 builder, moduleTranslation,
4177 moduleTranslation.lookupValue(linearVar), idx);
4178 }
4179 for (mlir::Value linearStep : simdOp.getLinearStepVars())
4180 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4181 }
4182
4183 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
4184 moduleTranslation, allocaIP, reductionDecls,
4185 privateReductionVariables, reductionVariableMap,
4186 deferredStores, isByRef)))
4187 return failure();
4188
4189 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
4190 opInst)
4191 .failed())
4192 return failure();
4193
4194 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
4195 // SIMD.
4196
4197 assert(afterAllocas.get()->getSinglePredecessor());
4198 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4199 moduleTranslation,
4200 afterAllocas.get()->getSinglePredecessor(),
4201 reductionDecls, privateReductionVariables,
4202 reductionVariableMap, isByRef, deferredStores)))
4203 return failure();
4204
4205 llvm::ConstantInt *simdlen = nullptr;
4206 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4207 simdlen = builder.getInt64(simdlenVar.value());
4208
4209 llvm::ConstantInt *safelen = nullptr;
4210 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4211 safelen = builder.getInt64(safelenVar.value());
4212
4213 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4214 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
4215
4216 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4217 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4218 mlir::OperandRange operands = simdOp.getAlignedVars();
4219 for (size_t i = 0; i < operands.size(); ++i) {
4220 llvm::Value *alignment = nullptr;
4221 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
4222 llvm::Type *ty = llvmVal->getType();
4223
4224 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4225 alignment = builder.getInt64(intAttr.getInt());
4226 assert(ty->isPointerTy() && "Invalid type for aligned variable");
4227 assert(alignment && "Invalid alignment value");
4228
4229 // Check if the alignment value is not a power of 2. If so, skip emitting
4230 // alignment.
4231 if (!intAttr.getValue().isPowerOf2())
4232 continue;
4233
4234 auto curInsert = builder.saveIP();
4235 builder.SetInsertPoint(sourceBlock);
4236 llvmVal = builder.CreateLoad(ty, llvmVal);
4237 builder.restoreIP(curInsert);
4238 alignedVars[llvmVal] = alignment;
4239 }
4240
4242 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
4243
4244 if (failed(handleError(regionBlock, opInst)))
4245 return failure();
4246
4247 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
4248 // Emit Initialization for linear variables
4249 if (simdOp.getLinearVars().size()) {
4250 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4251 loopInfo->getPreheader());
4252
4253 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4254 loopInfo->getIndVar());
4255 }
4256 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4257
4258 ompBuilder->applySimd(loopInfo, alignedVars,
4259 simdOp.getIfExpr()
4260 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
4261 : nullptr,
4262 order, simdlen, safelen);
4263
4264 linearClauseProcessor.emitStoresForLinearVar(builder);
4265
4266 // Check if this SIMD loop contains ordered regions
4267 bool hasOrderedRegions = false;
4268 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4269 hasOrderedRegions = true;
4270 return WalkResult::interrupt();
4271 });
4272
4273 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
4274 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
4275 llvm::BasicBlock *endBB = *regionBlock;
4276 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4277 "omp.loop_nest.region", index);
4278
4279 if (hasOrderedRegions) {
4280 // Also rewrite uses in ordered regions so they read the current value
4281 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4282 "omp.ordered.region", index);
4283 // Also rewrite uses in finalize blocks (code after ordered regions)
4284 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4285 "omp_region.finalize", index);
4286 }
4287 }
4288
4289 // We now need to reduce the per-simd-lane reduction variable into the
4290 // original variable. This works a bit differently to other reductions (e.g.
4291 // wsloop) because we don't need to call into the OpenMP runtime to handle
4292 // threads: everything happened in this one thread.
4293 for (auto [i, tuple] : llvm::enumerate(
4294 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4295 privateReductionVariables))) {
4296 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4297
4298 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
4299 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
4300 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
4301
4302 // We have one less load for by-ref case because that load is now inside of
4303 // the reduction region.
4304 llvm::Value *redValue = originalVariable;
4305 if (!byRef)
4306 redValue =
4307 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
4308 llvm::Value *privateRedValue = builder.CreateLoad(
4309 reductionType, privateReductionVar, "red.private.value." + Twine(i));
4310 llvm::Value *reduced;
4311
4312 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4313 if (failed(handleError(res, opInst)))
4314 return failure();
4315 builder.restoreIP(res.get());
4316
4317 // For by-ref case, the store is inside of the reduction region.
4318 if (!byRef)
4319 builder.CreateStore(reduced, originalVariable);
4320 }
4321
4322 // After the construct, deallocate private reduction variables.
4323 SmallVector<Region *> reductionRegions;
4324 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4325 [](omp::DeclareReductionOp reductionDecl) {
4326 return &reductionDecl.getCleanupRegion();
4327 });
4328 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
4329 moduleTranslation, builder,
4330 "omp.reduction.cleanup")))
4331 return failure();
4332
4333 return cleanupPrivateVars(simdOp, builder, moduleTranslation, simdOp.getLoc(),
4334 privateVarsInfo);
4335}
4336
4337/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
4338static LogicalResult
4339convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
4340 LLVM::ModuleTranslation &moduleTranslation) {
4341 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4342 auto loopOp = cast<omp::LoopNestOp>(opInst);
4343
4344 if (failed(checkImplementationStatus(opInst)))
4345 return failure();
4346
4347 // Set up the source location value for OpenMP runtime.
4348 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4349
4350 // Generator of the canonical loop body.
4353 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4354 llvm::Value *iv) -> llvm::Error {
4355 // Make sure further conversions know about the induction variable.
4356 moduleTranslation.mapValue(
4357 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4358
4359 // Capture the body insertion point for use in nested loops. BodyIP of the
4360 // CanonicalLoopInfo always points to the beginning of the entry block of
4361 // the body.
4362 bodyInsertPoints.push_back(ip);
4363
4364 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4365 return llvm::Error::success();
4366
4367 // Convert the body of the loop.
4368 builder.restoreIP(ip);
4370 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
4371 if (!regionBlock)
4372 return regionBlock.takeError();
4373
4374 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4375 return llvm::Error::success();
4376 };
4377
4378 // Delegate actual loop construction to the OpenMP IRBuilder.
4379 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
4380 // loop, i.e. it has a positive step, uses signed integer semantics.
4381 // Reconsider this code when the nested loop operation clearly supports more
4382 // cases.
4383 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4384 llvm::Value *lowerBound =
4385 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
4386 llvm::Value *upperBound =
4387 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
4388 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
4389
4390 // Make sure loop trip count are emitted in the preheader of the outermost
4391 // loop at the latest so that they are all available for the new collapsed
4392 // loop will be created below.
4393 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4394 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4395 if (i != 0) {
4396 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4397 ompLoc.DL);
4398 computeIP = loopInfos.front()->getPreheaderIP();
4399 }
4400
4402 ompBuilder->createCanonicalLoop(
4403 loc, bodyGen, lowerBound, upperBound, step,
4404 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
4405
4406 if (failed(handleError(loopResult, *loopOp)))
4407 return failure();
4408
4409 loopInfos.push_back(*loopResult);
4410 }
4411
4412 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4413 loopInfos.front()->getAfterIP();
4414
4415 // Do tiling.
4416 if (const auto &tiles = loopOp.getTileSizes()) {
4417 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4419
4420 for (auto tile : tiles.value()) {
4421 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
4422 tileSizes.push_back(tileVal);
4423 }
4424
4425 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4426 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4427
4428 // Update afterIP to get the correct insertion point after
4429 // tiling.
4430 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4431 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4432 afterIP = {afterAfterBB, afterAfterBB->begin()};
4433
4434 // Update the loop infos.
4435 loopInfos.clear();
4436 for (const auto &newLoop : newLoops)
4437 loopInfos.push_back(newLoop);
4438 } // Tiling done.
4439
4440 // Do collapse.
4441 const auto &numCollapse = loopOp.getCollapseNumLoops();
4443 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4444
4445 auto newTopLoopInfo =
4446 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4447
4448 assert(newTopLoopInfo && "New top loop information is missing");
4449 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
4450 [&](OpenMPLoopInfoStackFrame &frame) {
4451 frame.loopInfo = newTopLoopInfo;
4452 return WalkResult::interrupt();
4453 });
4454
4455 // Continue building IR after the loop. Note that the LoopInfo returned by
4456 // `collapseLoops` points inside the outermost loop and is intended for
4457 // potential further loop transformations. Use the insertion point stored
4458 // before collapsing loops instead.
4459 builder.restoreIP(afterIP);
4460 return success();
4461}
4462
4463/// Convert an omp.canonical_loop to LLVM-IR
4464static LogicalResult
4465convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
4466 LLVM::ModuleTranslation &moduleTranslation) {
4467 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4468
4469 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4470 Value loopIV = op.getInductionVar();
4471 Value loopTC = op.getTripCount();
4472
4473 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
4474
4476 ompBuilder->createCanonicalLoop(
4477 loopLoc,
4478 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4479 // Register the mapping of MLIR induction variable to LLVM-IR
4480 // induction variable
4481 moduleTranslation.mapValue(loopIV, llvmIV);
4482
4483 builder.restoreIP(ip);
4485 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
4486 moduleTranslation);
4487
4488 return bodyGenStatus.takeError();
4489 },
4490 llvmTC, "omp.loop");
4491 if (!llvmOrError)
4492 return op.emitError(llvm::toString(llvmOrError.takeError()));
4493
4494 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4495 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4496 builder.restoreIP(afterIP);
4497
4498 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
4499 if (Value cli = op.getCli())
4500 moduleTranslation.mapOmpLoop(cli, llvmCLI);
4501
4502 return success();
4503}
4504
4505/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
4506/// OpenMPIRBuilder.
4507static LogicalResult
4508applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
4509 LLVM::ModuleTranslation &moduleTranslation) {
4510 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4511
4512 Value applyee = op.getApplyee();
4513 assert(applyee && "Loop to apply unrolling on required");
4514
4515 llvm::CanonicalLoopInfo *consBuilderCLI =
4516 moduleTranslation.lookupOMPLoop(applyee);
4517 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4518 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4519
4520 moduleTranslation.invalidateOmpLoop(applyee);
4521 return success();
4522}
4523
4524/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
4525/// OpenMPIRBuilder.
4526static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4527 LLVM::ModuleTranslation &moduleTranslation) {
4528 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4529 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4530
4532 SmallVector<llvm::Value *> translatedSizes;
4533
4534 for (Value size : op.getSizes()) {
4535 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
4536 assert(translatedSize &&
4537 "sizes clause arguments must already be translated");
4538 translatedSizes.push_back(translatedSize);
4539 }
4540
4541 for (Value applyee : op.getApplyees()) {
4542 llvm::CanonicalLoopInfo *consBuilderCLI =
4543 moduleTranslation.lookupOMPLoop(applyee);
4544 assert(applyee && "Canonical loop must already been translated");
4545 translatedLoops.push_back(consBuilderCLI);
4546 }
4547
4548 auto generatedLoops =
4549 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4550 if (!op.getGeneratees().empty()) {
4551 for (auto [mlirLoop, genLoop] :
4552 zip_equal(op.getGeneratees(), generatedLoops))
4553 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
4554 }
4555
4556 // CLIs can only be consumed once
4557 for (Value applyee : op.getApplyees())
4558 moduleTranslation.invalidateOmpLoop(applyee);
4559
4560 return success();
4561}
4562
4563/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
4564/// OpenMPIRBuilder.
4565static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4566 LLVM::ModuleTranslation &moduleTranslation) {
4567 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4568 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4569
4570 // Select what CLIs are going to be fused
4571 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
4572 for (size_t i = 0; i < op.getApplyees().size(); i++) {
4573 Value applyee = op.getApplyees()[i];
4574 llvm::CanonicalLoopInfo *consBuilderCLI =
4575 moduleTranslation.lookupOMPLoop(applyee);
4576 assert(applyee && "Canonical loop must already been translated");
4577 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4578 beforeFuse.push_back(consBuilderCLI);
4579 else if (op.getCount().has_value() &&
4580 i >= op.getFirst().value() + op.getCount().value() - 1)
4581 afterFuse.push_back(consBuilderCLI);
4582 else
4583 toFuse.push_back(consBuilderCLI);
4584 }
4585 assert(
4586 (op.getGeneratees().empty() ||
4587 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4588 "Wrong number of generatees");
4589
4590 // do the fuse
4591 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4592 if (!op.getGeneratees().empty()) {
4593 size_t i = 0;
4594 for (; i < beforeFuse.size(); i++)
4595 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4596 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4597 for (; i < afterFuse.size(); i++)
4598 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4599 }
4600
4601 // CLIs can only be consumed once
4602 for (Value applyee : op.getApplyees())
4603 moduleTranslation.invalidateOmpLoop(applyee);
4604
4605 return success();
4606}
4607
4608/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
4609static llvm::AtomicOrdering
4610convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
4611 if (!ao)
4612 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
4613
4614 switch (*ao) {
4615 case omp::ClauseMemoryOrderKind::Seq_cst:
4616 return llvm::AtomicOrdering::SequentiallyConsistent;
4617 case omp::ClauseMemoryOrderKind::Acq_rel:
4618 return llvm::AtomicOrdering::AcquireRelease;
4619 case omp::ClauseMemoryOrderKind::Acquire:
4620 return llvm::AtomicOrdering::Acquire;
4621 case omp::ClauseMemoryOrderKind::Release:
4622 return llvm::AtomicOrdering::Release;
4623 case omp::ClauseMemoryOrderKind::Relaxed:
4624 return llvm::AtomicOrdering::Monotonic;
4625 }
4626 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
4627}
4628
4629/// Convert omp.atomic.read operation to LLVM IR.
4630static LogicalResult
4631convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
4632 LLVM::ModuleTranslation &moduleTranslation) {
4633 auto readOp = cast<omp::AtomicReadOp>(opInst);
4634 if (failed(checkImplementationStatus(opInst)))
4635 return failure();
4636
4637 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4638 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4639 findAllocInsertPoints(builder, moduleTranslation);
4640
4641 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4642
4643 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
4644 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
4645 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
4646
4647 llvm::Type *elementType =
4648 moduleTranslation.convertType(readOp.getElementType());
4649
4650 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
4651 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
4652 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4653 return success();
4654}
4655
4656/// Converts an omp.atomic.write operation to LLVM IR.
4657static LogicalResult
4658convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
4659 LLVM::ModuleTranslation &moduleTranslation) {
4660 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4661 if (failed(checkImplementationStatus(opInst)))
4662 return failure();
4663
4664 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4665 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4666 findAllocInsertPoints(builder, moduleTranslation);
4667
4668 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4669 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
4670 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
4671 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
4672 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
4673 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
4674 /*isVolatile=*/false};
4675 builder.restoreIP(
4676 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4677 return success();
4678}
4679
4680/// Converts an LLVM dialect binary operation to the corresponding enum value
4681/// for `atomicrmw` supported binary operation.
4682static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
4684 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
4685 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
4686 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
4687 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
4688 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
4689 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
4690 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
4691 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
4692 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
4693 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4694}
4695
4696static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
4697 bool &isIgnoreDenormalMode,
4698 bool &isFineGrainedMemory,
4699 bool &isRemoteMemory) {
4700 isIgnoreDenormalMode = false;
4701 isFineGrainedMemory = false;
4702 isRemoteMemory = false;
4703 if (atomicUpdateOp &&
4704 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4705 mlir::omp::AtomicControlAttr atomicControlAttr =
4706 atomicUpdateOp.getAtomicControlAttr();
4707 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4708 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4709 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4710 }
4711}
4712
4713/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
4714static LogicalResult
4715convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
4716 llvm::IRBuilderBase &builder,
4717 LLVM::ModuleTranslation &moduleTranslation) {
4718 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4719 if (failed(checkImplementationStatus(*opInst)))
4720 return failure();
4721
4722 // Convert values and types.
4723 auto &innerOpList = opInst.getRegion().front().getOperations();
4724 bool isXBinopExpr{false};
4725 llvm::AtomicRMWInst::BinOp binop;
4726 mlir::Value mlirExpr;
4727 llvm::Value *llvmExpr = nullptr;
4728 llvm::Value *llvmX = nullptr;
4729 llvm::Type *llvmXElementType = nullptr;
4730 if (innerOpList.size() == 2) {
4731 // The two operations here are the update and the terminator.
4732 // Since we can identify the update operation, there is a possibility
4733 // that we can generate the atomicrmw instruction.
4734 mlir::Operation &innerOp = *opInst.getRegion().front().begin();
4735 if (!llvm::is_contained(innerOp.getOperands(),
4736 opInst.getRegion().getArgument(0))) {
4737 return opInst.emitError("no atomic update operation with region argument"
4738 " as operand found inside atomic.update region");
4739 }
4740 binop = convertBinOpToAtomic(innerOp);
4741 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
4742 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4743 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4744 } else {
4745 // Since the update region includes more than one operation
4746 // we will resort to generating a cmpxchg loop.
4747 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4748 }
4749 llvmX = moduleTranslation.lookupValue(opInst.getX());
4750 llvmXElementType = moduleTranslation.convertType(
4751 opInst.getRegion().getArgument(0).getType());
4752 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4753 /*isSigned=*/false,
4754 /*isVolatile=*/false};
4755
4756 llvm::AtomicOrdering atomicOrdering =
4757 convertAtomicOrdering(opInst.getMemoryOrder());
4758
4759 // Generate update code.
4760 auto updateFn =
4761 [&opInst, &moduleTranslation](
4762 llvm::Value *atomicx,
4763 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4764 Block &bb = *opInst.getRegion().begin();
4765 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
4766 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4767 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4768 return llvm::make_error<PreviouslyReportedError>();
4769
4770 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4771 assert(yieldop && yieldop.getResults().size() == 1 &&
4772 "terminator must be omp.yield op and it must have exactly one "
4773 "argument");
4774 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4775 };
4776
4777 bool isIgnoreDenormalMode;
4778 bool isFineGrainedMemory;
4779 bool isRemoteMemory;
4780 extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
4781 isRemoteMemory);
4782 // Handle ambiguous alloca, if any.
4783 auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
4784 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4785 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4786 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4787 atomicOrdering, binop, updateFn,
4788 isXBinopExpr, isIgnoreDenormalMode,
4789 isFineGrainedMemory, isRemoteMemory);
4790
4791 if (failed(handleError(afterIP, *opInst)))
4792 return failure();
4793
4794 builder.restoreIP(*afterIP);
4795 return success();
4796}
4797
/// Converts an omp.atomic.capture operation into LLVM IR through
/// OpenMPIRBuilder::createAtomicCapture.
///
/// The capture op pairs an atomic read with either an atomic write or an
/// atomic update. The expression operand, the binary operation and the
/// postfix flag are derived from whichever of the two is present.
///
/// Returns failure when the implementation-status check fails, when a two-op
/// update region does not use its region argument, or when the builder reports
/// an error. On success the IR builder is restored to the insertion point
/// returned by the OpenMPIRBuilder.
4798static LogicalResult
4799convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
4800 llvm::IRBuilderBase &builder,
4801 LLVM::ModuleTranslation &moduleTranslation) {
4802 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4803 if (failed(checkImplementationStatus(*atomicCaptureOp)))
4804 return failure();
4805
4806 mlir::Value mlirExpr;
4807 bool isXBinopExpr = false, isPostfixUpdate = false;
4808 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4809
// Exactly one of these is expected to be non-null (asserted below).
4810 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4811 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4812
4813 assert((atomicUpdateOp || atomicWriteOp) &&
4814 "internal op must be an atomic.update or atomic.write op");
4815
4816 if (atomicWriteOp) {
// A read paired with a write is always handled as a postfix update; the
// written expression becomes the update expression.
4817 isPostfixUpdate = true;
4818 mlirExpr = atomicWriteOp.getExpr();
4819 } else {
// Postfix when the update is the second op of the capture region.
4820 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4821 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4822 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4823 // Find the binary update operation that uses the region argument
4824 // and get the expression to update
// Only a region of exactly two ops (one op plus the terminator) is analysed;
// in every other case binop stays BAD_BINOP and mlirExpr stays null.
4825 if (innerOpList.size() == 2) {
4826 mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
4827 if (!llvm::is_contained(innerOp.getOperands(),
4828 atomicUpdateOp.getRegion().getArgument(0))) {
4829 return atomicUpdateOp.emitError(
4830 "no atomic update operation with region argument"
4831 " as operand found inside atomic.update region");
4832 }
4833 binop = convertBinOpToAtomic(innerOp);
// isXBinopExpr records whether x is the left-hand operand; the other
// operand is the expression.
4834 isXBinopExpr =
4835 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4836 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4837 } else {
4838 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4839 }
4840 }
4841
// NOTE(review): mlirExpr may be a null Value here (update region with more
// than two ops); presumably lookupValue tolerates that — confirm.
4842 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4843 llvm::Value *llvmX =
4844 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4845 llvm::Value *llvmV =
4846 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4847 llvm::Type *llvmXElementType = moduleTranslation.convertType(
4848 atomicCaptureOp.getAtomicReadOp().getElementType());
// Both x and v are described with the read op's element type.
4849 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4850 /*isSigned=*/false,
4851 /*isVolatile=*/false};
4852 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4853 /*isSigned=*/false,
4854 /*isVolatile=*/false};
4855
4856 llvm::AtomicOrdering atomicOrdering =
4857 convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
4858
// Callback handed to the OpenMPIRBuilder to produce the new value of x. For a
// write it returns the already-translated expression; for an update it maps
// the region argument to `atomicx`, translates the region block in place and
// returns the value yielded by its terminator.
4859 auto updateFn =
4860 [&](llvm::Value *atomicx,
4861 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4862 if (atomicWriteOp)
4863 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
4864 Block &bb = *atomicUpdateOp.getRegion().begin();
4865 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
4866 atomicx);
4867 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4868 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4869 return llvm::make_error<PreviouslyReportedError>();
4870
4871 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4872 assert(yieldop && yieldop.getResults().size() == 1 &&
4873 "terminator must be omp.yield op and it must have exactly one "
4874 "argument");
4875 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4876 };
4877
// NOTE(review): atomicUpdateOp is null in the atomic-write case; this assumes
// extractAtomicControlFlags copes with a null op — confirm.
4878 bool isIgnoreDenormalMode;
4879 bool isFineGrainedMemory;
4880 bool isRemoteMemory;
4881 extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
4882 isFineGrainedMemory, isRemoteMemory);
4883 // Handle ambiguous alloca, if any.
4884 auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
4885 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4886 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4887 ompBuilder->createAtomicCapture(
4888 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4889 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4890 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4891
4892 if (failed(handleError(afterIP, *atomicCaptureOp)))
4893 return failure();
4894
// Continue emitting IR after the generated atomic sequence.
4895 builder.restoreIP(*afterIP);
4896 return success();
4897}
4898
4899/// Helper to extract the OMPAtomicCompareOp from an integer comparison
4900/// predicate. Returns std::nullopt for unsupported predicates.
4901static std::optional<llvm::omp::OMPAtomicCompareOp>
4902convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
4903 switch (predicate) {
4904 case LLVM::ICmpPredicate::eq:
4905 return llvm::omp::OMPAtomicCompareOp::EQ;
4906 case LLVM::ICmpPredicate::slt:
4907 case LLVM::ICmpPredicate::ult:
4908 return llvm::omp::OMPAtomicCompareOp::MIN;
4909 case LLVM::ICmpPredicate::sgt:
4910 case LLVM::ICmpPredicate::ugt:
4911 return llvm::omp::OMPAtomicCompareOp::MAX;
4912 default:
4913 return std::nullopt;
4914 }
4915}
4916
4917/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
4918/// predicate. Returns std::nullopt for unsupported predicates.
4919static std::optional<llvm::omp::OMPAtomicCompareOp>
4920convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
4921 switch (predicate) {
4922 case LLVM::FCmpPredicate::oeq:
4923 case LLVM::FCmpPredicate::ueq:
4924 return llvm::omp::OMPAtomicCompareOp::EQ;
4925 case LLVM::FCmpPredicate::olt:
4926 case LLVM::FCmpPredicate::ult:
4927 return llvm::omp::OMPAtomicCompareOp::MIN;
4928 case LLVM::FCmpPredicate::ogt:
4929 case LLVM::FCmpPredicate::ugt:
4930 return llvm::omp::OMPAtomicCompareOp::MAX;
4931 default:
4932 return std::nullopt;
4933 }
4934}
4935
4936/// Converts an omp.atomic.compare operation to LLVM IR.
4937///
4938/// if (x == e) x = d
4939/// The region contains a comparison + select pattern:
4940/// ^bb0(%xval: T):
4941/// %cmp = llvm.icmp/fcmp <pred> %xval, %e : T
4942/// %sel = llvm.select %cmp, %d, %xval : i1, T
4943/// omp.yield(%sel : T)
4944///
4945/// From MLIR extract:
4946/// 1) comparison operator
4947/// 2) expected value (e)
4948/// 3) desired value (d)
4949/// These are passed to OpenMPIRBuilder::createAtomicCompare which generates
4950/// the actual cmpxchg / atomicrmw instruction.
4951///
4952static LogicalResult
4953convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
4954 llvm::IRBuilderBase &builder,
4955 LLVM::ModuleTranslation &moduleTranslation) {
4956 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4957 if (failed(checkImplementationStatus(*atomicCompareOp)))
4958 return failure();
4959
4960 Region &region = atomicCompareOp.getRegion();
4961 Block &block = region.front();
4962
4963 // Determine element type from the region block argument
4964 llvm::Type *llvmXElementType =
4965 moduleTranslation.convertType(block.getArgument(0).getType());
4966 if (!llvmXElementType)
4967 return atomicCompareOp.emitError(
4968 "unable to determine element type for atomic compare");
4969
4970 llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
4971
4972 // IsSigned is determined from the comparison predicate in the region.
4973 // Signed ICmp predicates (slt/sgt) set this to true; unsigned (ult/ugt)
4974 // leave it false. For EQ and float comparisons, signedness is irrelevant.
4975 bool isSigned = false;
4976 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4977 isSigned,
4978 /*IsVolatile=*/false};
4979
4980 llvm::AtomicOrdering atomicOrdering =
4981 convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
4982
4983 auto isAtomicComparePatternOp = [](Operation &op) {
4984 return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
4985 LLVM::OrOp>(op);
4986 };
4987
4988 // Pre-translate operations inside the region that compute e and d (e.g.,
4989 // GEP, loads for dereferencing Fortran pointers) but are not part of the
4990 // atomic compare-and-swap pattern (icmp/fcmp, select, and/or).
4991 //
4992 // 1) Validity: The OpenMP spec requires e and d to be evaluated before the
4993 // atomic operation, so emitting their computation here is correct.
4994 // 2) Memory effects: These ops only depend on values defined outside the
4995 // region. They cannot observe the block argument (%xval), which is the
4996 // value loaded atomically by cmpxchg and does not exist yet.
4997 // 3) Invariant enforcement: The `allOperandsMapped` check below skips any
4998 // op whose operands include the unmapped block argument, guaranteeing
4999 // only region-external-dependent ops are pre-translated.
5000 for (Operation &op : block.without_terminator()) {
5001 // Skip operations that form the atomic compare pattern — these are
5002 // not emitted as individual instructions but are analyzed below to
5003 // extract the comparison predicate, expected value (e), and desired
5004 // value (d) for generating a single cmpxchg/atomicrmw.
5005 if (isAtomicComparePatternOp(op))
5006 continue;
5007
5008 // Avoid translating ops that depend on the unmapped block argument.
5009 bool allOperandsMapped = llvm::all_of(op.getOperands(), [&](mlir::Value v) {
5010 return moduleTranslation.lookupValue(v) != nullptr;
5011 });
5012 if (!allOperandsMapped)
5013 continue;
5014
5015 if (failed(moduleTranslation.convertOperation(op, builder)))
5016 return atomicCompareOp.emitError(
5017 "failed to translate operation inside atomic compare region");
5018 }
5019
5020 // Look up a value that may have been pre-translated or defined outside the
5021 // region.
5022 auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
5023 // Check if the value is already mapped (pre-translated or defined outside).
5024 if (llvm::Value *existing = moduleTranslation.lookupValue(val))
5025 return existing;
5026 // Fallback for a single LoadOp whose address is mapped but whose result
5027 // was not pre-translated.
5028 if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
5029 if (loadOp->getParentRegion() == &region) {
5030 llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
5031 if (!loadAddr)
5032 return nullptr;
5033 llvm::Type *loadType =
5034 moduleTranslation.convertType(loadOp.getResult().getType());
5035 return builder.CreateLoad(loadType, loadAddr);
5036 }
5037 }
5038 return nullptr;
5039 };
5040
5041 // Walk the region to extract comparison predicate, eVal, and dVal.
5042 // if (x == eVal) x = dVal
5043 llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5044 llvm::Value *eVal = nullptr;
5045 llvm::Value *dVal = nullptr;
5046 bool isXBinopExpr = false;
5047
5048 auto traceToAggregate = [](mlir::Value v) -> mlir::Value {
5049 if (auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
5050 return extractOp.getContainer();
5051 return nullptr;
5052 };
5053
5054 // Check for a decomposed complex comparison pattern:
5055 // %re_x = llvm.extractvalue %xval[0]
5056 // %re_e = llvm.extractvalue %eStruct[0]
5057 // %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
5058 // %im_x = llvm.extractvalue %xval[1]
5059 // %im_e = llvm.extractvalue %eStruct[1]
5060 // %cmp_im = llvm.fcmp "oeq" %im_x, %im_e
5061 // %cmp = llvm.and %cmp_re, %cmp_im (for EQ)
5062 // Detect this by looking for AndOp/OrOp whose operands are both FCmpOps
5063 // operating on ExtractValueOps from the block argument.
5064 bool isComplexPattern = false;
5065 for (Operation &op : block.getOperations()) {
5066 if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
5067 continue;
5068
5069 // Using : %cmp = llvm.and %cmp_re, %cmp_im
5070 auto lhsFcmp = op.getOperand(0).getDefiningOp<LLVM::FCmpOp>();
5071 auto rhsFcmp = op.getOperand(1).getDefiningOp<LLVM::FCmpOp>();
5072 if (!lhsFcmp || !rhsFcmp)
5073 continue;
5074
5075 // Using : %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
5076 // Check presence of x (block argument) and get e.
5077 mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
5078 mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
5079 bool lhsXIsOp0 = (lhsAgg0 == block.getArgument(0));
5080 bool lhsXIsOp1 = (lhsAgg1 == block.getArgument(0));
5081 if (!lhsXIsOp0 && !lhsXIsOp1)
5082 continue;
5083 mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
5084 if (!eAggregate)
5085 continue;
5086
5087 if (isa<LLVM::AndOp>(op))
5088 compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5089 else
5090 // OrOp corresponds to NE, which is not a valid atomic compare op.
5091 return atomicCompareOp.emitError(
5092 "unsupported comparison predicate (NE) for complex atomic compare");
5093
5094 isXBinopExpr = lhsXIsOp0;
5095 eVal = materializeValue(eAggregate);
5096 isComplexPattern = true;
5097 break;
5098 }
5099
5100 if (isComplexPattern) {
5101 // dVal from SelectOp or YieldOp.
5102 for (Operation &op : block.getOperations()) {
5103 if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5104 dVal = materializeValue(selectOp.getTrueValue());
5105 break;
5106 }
5107 }
5108 if (!dVal) {
5109 auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
5110 if (yieldOp.getResults().empty())
5111 return atomicCompareOp.emitError(
5112 "failed to extract desired value (d) from atomic compare region");
5113 dVal = materializeValue(yieldOp.getResults()[0]);
5114 }
5115
5116 const llvm::DataLayout &DL =
5117 builder.GetInsertBlock()->getModule()->getDataLayout();
5118 unsigned totalBits =
5119 DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
5120
5121 llvm::IntegerType *intTy =
5122 llvm::IntegerType::get(builder.getContext(), totalBits);
5123
5124 llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
5125 llvm::Align intAlign = DL.getABITypeAlign(intTy);
5126 llvm::Align maxAlign = std::max(complexAlign, intAlign);
5127
5128 llvm::AllocaInst *eAlloca =
5129 builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
5130 eAlloca->setAlignment(maxAlign);
5131 llvm::AllocaInst *dAlloca =
5132 builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
5133 dAlloca->setAlignment(maxAlign);
5134
5135 builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
5136 llvm::Value *eInt =
5137 builder.CreateAlignedLoad(intTy, eAlloca, maxAlign, "cmplx.e.int");
5138 builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
5139 llvm::Value *dInt =
5140 builder.CreateAlignedLoad(intTy, dAlloca, maxAlign, "cmplx.d.int");
5141
5142 llvm::AtomicOrdering failOrdering =
5143 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
5144 builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign, atomicOrdering,
5145 failOrdering);
5146
5147 // Emit flush after atomic compare if needed (for release, acq_rel,
5148 // seq_cst orderings).
5149 if (atomicOrdering == llvm::AtomicOrdering::Release ||
5150 atomicOrdering == llvm::AtomicOrdering::AcquireRelease ||
5151 atomicOrdering == llvm::AtomicOrdering::SequentiallyConsistent) {
5152 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5153 ompBuilder->createFlush(ompLoc);
5154 }
5155 return success();
5156 } else {
5157
5158 for (Operation &op : block.getOperations()) {
5159 if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
5160 auto maybeOp =
5161 convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
5162 if (!maybeOp)
5163 return atomicCompareOp.emitError(
5164 "unsupported comparison predicate in atomic compare");
5165 compareOp = *maybeOp;
5166
5167 LLVM::ICmpPredicate pred = icmpOp.getPredicate();
5168 isSigned = (pred == LLVM::ICmpPredicate::slt ||
5169 pred == LLVM::ICmpPredicate::sgt ||
5170 pred == LLVM::ICmpPredicate::sle ||
5171 pred == LLVM::ICmpPredicate::sge);
5172
5173 // Identify which operand is the block argument (x) and which is e.
5174 isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0));
5175 mlir::Value eOperand =
5176 isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
5177 eVal = materializeValue(eOperand);
5178 } else if (auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
5179 auto maybeOp =
5180 convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate());
5181 if (!maybeOp)
5182 return atomicCompareOp.emitError(
5183 "unsupported comparison predicate in atomic compare");
5184 compareOp = *maybeOp;
5185
5186 isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0));
5187 mlir::Value eOperand =
5188 isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
5189 eVal = materializeValue(eOperand);
5190 } else if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5191 if (!dVal)
5192 dVal = materializeValue(selectOp.getTrueValue());
5193 }
5194 }
5195 }
5196
5197 // For non-complex patterns, also extract dVal from SelectOp.
5198 if (!dVal) {
5199 for (Operation &op : block.getOperations()) {
5200 if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5201 dVal = materializeValue(selectOp.getTrueValue());
5202 break;
5203 }
5204 }
5205 }
5206
5207 if (!eVal)
5208 return atomicCompareOp.emitError(
5209 "failed to extract expected value (e) from atomic compare region");
5210 if (!dVal) {
5211 // Fall back to the yield operand.
5212 auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
5213 if (yieldOp.getResults().empty())
5214 return atomicCompareOp.emitError(
5215 "failed to extract desired value (d) from atomic compare region");
5216 dVal = materializeValue(yieldOp.getResults()[0]);
5217 }
5218
5219 llvmAtomicX.IsSigned = isSigned;
5220
5221 llvm::OpenMPIRBuilder::AtomicOpValue vOpVal = {nullptr, nullptr, false,
5222 false};
5223 llvm::OpenMPIRBuilder::AtomicOpValue rOpVal = {nullptr, nullptr, false,
5224 false};
5225 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5226
5227 bool savedHandleFPNegZero = ompBuilder->setHandleFPNegZero(true);
5228 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5229 ompBuilder->createAtomicCompare(ompLoc, llvmAtomicX, vOpVal, rOpVal, eVal,
5230 dVal, atomicOrdering, compareOp,
5231 isXBinopExpr, false, false);
5232 ompBuilder->setHandleFPNegZero(savedHandleFPNegZero);
5233
5234 if (failed(handleError(afterIP, *atomicCompareOp)))
5235 return failure();
5236
5237 builder.restoreIP(*afterIP);
5238 return success();
5239}
5240
5241static llvm::omp::Directive convertCancellationConstructType(
5242 omp::ClauseCancellationConstructType directive) {
5243 switch (directive) {
5244 case omp::ClauseCancellationConstructType::Loop:
5245 return llvm::omp::Directive::OMPD_for;
5246 case omp::ClauseCancellationConstructType::Parallel:
5247 return llvm::omp::Directive::OMPD_parallel;
5248 case omp::ClauseCancellationConstructType::Sections:
5249 return llvm::omp::Directive::OMPD_sections;
5250 case omp::ClauseCancellationConstructType::Taskgroup:
5251 return llvm::omp::Directive::OMPD_taskgroup;
5252 }
5253 llvm_unreachable("Unhandled cancellation construct type");
5254}
5255
5256static LogicalResult
5257convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
5258 LLVM::ModuleTranslation &moduleTranslation) {
5259 if (failed(checkImplementationStatus(*op.getOperation())))
5260 return failure();
5261
5262 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5263 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5264
5265 llvm::Value *ifCond = nullptr;
5266 if (Value ifVar = op.getIfExpr())
5267 ifCond = moduleTranslation.lookupValue(ifVar);
5268
5269 llvm::omp::Directive cancelledDirective =
5270 convertCancellationConstructType(op.getCancelDirective());
5271
5272 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5273 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
5274
5275 if (failed(handleError(afterIP, *op.getOperation())))
5276 return failure();
5277
5278 builder.restoreIP(afterIP.get());
5279
5280 return success();
5281}
5282
5283static LogicalResult
5284convertOmpCancellationPoint(omp::CancellationPointOp op,
5285 llvm::IRBuilderBase &builder,
5286 LLVM::ModuleTranslation &moduleTranslation) {
5287 if (failed(checkImplementationStatus(*op.getOperation())))
5288 return failure();
5289
5290 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5291 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5292
5293 llvm::omp::Directive cancelledDirective =
5294 convertCancellationConstructType(op.getCancelDirective());
5295
5296 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5297 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
5298
5299 if (failed(handleError(afterIP, *op.getOperation())))
5300 return failure();
5301
5302 builder.restoreIP(afterIP.get());
5303
5304 return success();
5305}
5306
5307/// Converts an OpenMP Threadprivate operation into LLVM IR using
5308/// OpenMPIRBuilder.
5309static LogicalResult
5310convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
5311 LLVM::ModuleTranslation &moduleTranslation) {
5312 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5313 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5314 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
5315
5316 if (failed(checkImplementationStatus(opInst)))
5317 return failure();
5318
5319 Value symAddr = threadprivateOp.getSymAddr();
5320 auto *symOp = symAddr.getDefiningOp();
5321
5322 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
5323 symOp = asCast.getOperand().getDefiningOp();
5324
5325 if (!isa<LLVM::AddressOfOp>(symOp))
5326 return opInst.emitError("Addressing symbol not found");
5327 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
5328
5329 LLVM::GlobalOp global =
5330 addressOfOp.getGlobal(moduleTranslation.symbolTable());
5331 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
5332 llvm::Type *type = globalValue->getValueType();
5333 llvm::TypeSize typeSize =
5334 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
5335 type);
5336 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
5337 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
5338 ompLoc, globalValue, size, global.getSymName() + ".cache");
5339 moduleTranslation.mapValue(opInst.getResult(0), callInst);
5340
5341 return success();
5342}
5343
5344static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
5345convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
5346 switch (deviceClause) {
5347 case mlir::omp::DeclareTargetDeviceType::host:
5348 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
5349 break;
5350 case mlir::omp::DeclareTargetDeviceType::nohost:
5351 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
5352 break;
5353 case mlir::omp::DeclareTargetDeviceType::any:
5354 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
5355 break;
5356 }
5357 llvm_unreachable("unhandled device clause");
5358}
5359
/// Translates a declare-target capture clause (to / link / enter / none) into
/// the equivalent OffloadEntriesInfoManager global-variable entry kind.
// NOTE(review): the declarator line carrying the function name is absent from
// this listing (between the return type and the parameter) — confirm against
// the original file.
5360static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
5362 mlir::omp::DeclareTargetCaptureClause captureClause) {
// No default case: every enumerator is listed, and falling out of the switch
// is treated as unreachable.
5363 switch (captureClause) {
5364 case mlir::omp::DeclareTargetCaptureClause::to:
5365 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
5366 case mlir::omp::DeclareTargetCaptureClause::link:
5367 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
5368 case mlir::omp::DeclareTargetCaptureClause::enter:
5369 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
5370 case mlir::omp::DeclareTargetCaptureClause::none:
5371 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
5372 }
5373 llvm_unreachable("unhandled capture clause");
5374}
5375
// Resolves `value` to the symbol operation it addresses: looks through an
// optional LLVM::AddrSpaceCastOp and, if the result is an LLVM::AddressOfOp,
// looks its global name up in the enclosing module. Returns nullptr when the
// value is not produced by an address-of op.
// NOTE(review): the signature line of this function is absent from this
// listing — confirm against the original file.
5377 Operation *op = value.getDefiningOp();
// `op` may be null (block argument); the *_if_present casts tolerate that.
5378 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5379 op = addrCast->getOperand(0).getDefiningOp();
5380 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
5381 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
5382 return modOp.lookupSymbol(addressOfOp.getGlobalName());
5383 }
5384 return nullptr;
5385}
5386
// Walks up the defining ops of `value`, looking through
// LLVM::AddrSpaceCastOp and through ops named "hlfir.declare" / "fir.declare"
// (following their first operand), and returns the first value that is not
// produced by one of those.
// NOTE(review): the signature line of this function is absent from this
// listing — confirm against the original file.
5388 while (Operation *op = value.getDefiningOp()) {
5389 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
5390 value = addrCast.getOperand();
5391 // Traces through hlfir.declare, fir.declare to reach the base address and
5392 // use for type lookup.
// The declare ops are matched by operation name string rather than by a
// C++ op class.
5393 else if (op->getName().getIdentifier() &&
5394 (op->getName().getIdentifier().str() == "hlfir.declare" ||
5395 op->getName().getIdentifier().str() == "fir.declare")) {
5396 if (op->getNumOperands() > 0)
5397 value = op->getOperand(0);
5398 else
5399 break;
5400 } else {
// Any other defining op ends the walk.
5401 break;
5402 }
5403 }
5404 return value;
5405}
5406
5407static llvm::SmallString<64>
5408getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
5409 llvm::OpenMPIRBuilder &ompBuilder,
5410 llvm::vfs::FileSystem &vfs) {
5411 llvm::SmallString<64> suffix;
5412 llvm::raw_svector_ostream os(suffix);
5413 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
5414 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
5415 auto fileInfoCallBack = [&loc]() {
5416 return std::pair<std::string, uint64_t>(
5417 llvm::StringRef(loc.getFilename()), loc.getLine());
5418 };
5419
5420 os << llvm::format(
5421 "_%x",
5422 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5423 }
5424 os << "_decl_tgt_ref_ptr";
5425
5426 return suffix;
5427}
5428
5429static bool isDeclareTargetLink(Value value) {
5430 if (auto declareTargetGlobal =
5431 dyn_cast_if_present<omp::DeclareTargetInterface>(
5432 getGlobalOpFromValue(value)))
5433 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5434 omp::DeclareTargetCaptureClause::link)
5435 return true;
5436 return false;
5437}
5438
5439static bool isDeclareTargetTo(Value value) {
5440 if (auto declareTargetGlobal =
5441 dyn_cast_if_present<omp::DeclareTargetInterface>(
5442 getGlobalOpFromValue(value)))
5443 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5444 omp::DeclareTargetCaptureClause::to ||
5445 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5446 omp::DeclareTargetCaptureClause::enter)
5447 return true;
5448 return false;
5449}
5450
5451// Returns the reference pointer generated by the lowering of the declare
5452// target operation in cases where the link clause is used or the to clause is
5453// used in USM mode.
// Returns nullptr when `value` does not refer to such a declare-target global.
// NOTE(review): the declarator line (function name and first parameter) and
// the line declaring `suffix` are absent from this listing — confirm against
// the original file.
5454static llvm::Value *
5456 LLVM::ModuleTranslation &moduleTranslation) {
5457 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5458 if (auto gOp =
5459 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
5460 if (auto declareTargetGlobal =
5461 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
5462 // In this case, we must utilise the reference pointer generated by
5463 // the declare target operation, similar to Clang
5464 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
5465 omp::DeclareTargetCaptureClause::link) ||
5466 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5467 omp::DeclareTargetCaptureClause::to &&
5468 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5470 gOp, *ompBuilder, moduleTranslation.getFileSystem());
5471
// If the global's own name already contains the suffix, it is looked up
// under its own name in the LLVM module.
5472 if (gOp.getSymName().contains(suffix))
5473 return moduleTranslation.getLLVMModule()->getNamedValue(
5474 gOp.getSymName());
5475
// Otherwise look up the global named "<symbol name><suffix>".
5476 return moduleTranslation.getLLVMModule()->getNamedValue(
5477 (gOp.getSymName().str() + suffix.str()).str());
5478 }
5479 }
5480 }
5481 return nullptr;
5482}
5483
5484namespace {
5485// Append customMappers information to existing MapInfosTy
5486struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
5487 SmallVector<Operation *, 4> Mappers;
5488
5489 /// Append arrays in \a CurInfo.
5490 void append(MapInfosTy &curInfo) {
5491 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5492 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
5493 }
5494};
5495// A small helper structure to contain data gathered
5496// for map lowering and coalese it into one area and
5497// avoiding extra computations such as searches in the
5498// llvm module for lowered mapped variables or checking
5499// if something is declare target (and retrieving the
5500// value) more than neccessary.
5501struct MapInfoData : MapInfosTy {
5502 llvm::SmallVector<bool, 4> IsDeclareTarget;
5503 llvm::SmallVector<bool, 4> IsAMember;
5504 // Identify if mapping was added by mapClause or use_device clauses.
5505 llvm::SmallVector<bool, 4> IsAMapping;
5506 llvm::SmallVector<mlir::Operation *, 4> MapClause;
5507 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
5508 // Stripped off array/pointer to get the underlying
5509 // element type
5510 llvm::SmallVector<llvm::Type *, 4> BaseType;
5511
5512 /// Append arrays in \a CurInfo.
5513 void append(MapInfoData &CurInfo) {
5514 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5515 CurInfo.IsDeclareTarget.end());
5516 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5517 OriginalValue.append(CurInfo.OriginalValue.begin(),
5518 CurInfo.OriginalValue.end());
5519 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5520 MapInfosTy::append(CurInfo);
5521 }
5522};
5523
5524enum class TargetDirectiveEnumTy : uint32_t {
5525 None = 0,
5526 Target = 1,
5527 TargetData = 2,
5528 TargetEnterData = 3,
5529 TargetExitData = 4,
5530 TargetUpdate = 5
5531};
5532
5533static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5534 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5535 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
5536 .Case([](omp::TargetEnterDataOp) {
5537 return TargetDirectiveEnumTy::TargetEnterData;
5538 })
5539 .Case([&](omp::TargetExitDataOp) {
5540 return TargetDirectiveEnumTy::TargetExitData;
5541 })
5542 .Case([&](omp::TargetUpdateOp) {
5543 return TargetDirectiveEnumTy::TargetUpdate;
5544 })
5545 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
5546 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
5547}
5548
5549} // namespace
5550
5551static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
5552 DataLayout &dl) {
5553 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5554 arrTy.getElementType()))
5555 return getArrayElementSizeInBits(nestedArrTy, dl);
5556 return dl.getTypeSizeInBits(arrTy.getElementType());
5557}
5558
5559// This function calculates the size to be offloaded for a specified type, given
5560// its associated map clause (which can contain bounds information which affects
5561// the total size), this size is calculated based on the underlying element type
5562// e.g. given a 1-D array of ints, we will calculate the size from the integer
5563// type * number of elements in the array. This size can be used in other
5564// calculations but is ultimately used as an argument to the OpenMP runtimes
5565// kernel argument structure which is generated through the combinedInfo data
5566// structures.
5567// This function is somewhat equivalent to Clang's getExprTypeSize inside of
5568// CGOpenMPRuntime.cpp.
5569static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
5570 Operation *clauseOp,
5571 llvm::Value *basePointer,
5572 llvm::Type *baseType,
5573 llvm::IRBuilderBase &builder,
5574 LLVM::ModuleTranslation &moduleTranslation) {
5575 if (auto memberClause =
5576 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5577 // This calculates the size to transfer based on bounds and the underlying
5578 // element type, provided bounds have been specified (Fortran
5579 // pointers/allocatables/target and arrays that have sections specified fall
5580 // into this as well)
5581 if (!memberClause.getBounds().empty()) {
5582 llvm::Value *elementCount = builder.getInt64(1);
5583 for (auto bounds : memberClause.getBounds()) {
5584 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5585 bounds.getDefiningOp())) {
5586 // The below calculation for the size to be mapped calculated from the
5587 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
5588 // multiply by the underlying element types byte size to get the full
5589 // size to be offloaded based on the bounds
5590 elementCount = builder.CreateMul(
5591 elementCount,
5592 builder.CreateAdd(
5593 builder.CreateSub(
5594 moduleTranslation.lookupValue(boundOp.getUpperBound()),
5595 moduleTranslation.lookupValue(boundOp.getLowerBound())),
5596 builder.getInt64(1)));
5597 }
5598 }
5599
5600 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
5601 // the size in inconsistent byte or bit format.
5602 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
5603 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5604 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
5605
5606 // The size in bytes x number of elements, the sizeInBytes stored is
5607 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
5608 // size, so we do some on the fly runtime math to get the size in
5609 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
5610 // some adjustment for members with more complex types.
5611 return builder.CreateMul(elementCount,
5612 builder.getInt64(underlyingTypeSzInBits / 8));
5613 }
5614 }
5615
5616 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
5617}
5618
5619// Convert the MLIR map flag set to the runtime map flag set for embedding
5620// in LLVM-IR. This is important as the two bit-flag lists do not correspond
5621// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
5622// Certain flags are discarded here such as RefPtee and co.
5623static llvm::omp::OpenMPOffloadMappingFlags
5624convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
5625 const bool hasExplicitMap =
5626 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
5627 omp::ClauseMapFlags::none;
5628
5629 llvm::omp::OpenMPOffloadMappingFlags mapType =
5630 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5631
5632 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::to))
5633 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5634
5635 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::from))
5636 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5637
5638 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::always))
5639 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5640
5641 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::del))
5642 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5643
5644 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::return_param))
5645 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5646
5647 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::priv))
5648 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5649
5650 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::literal))
5651 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5652
5653 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::implicit))
5654 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5655
5656 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::close))
5657 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5658
5659 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::present))
5660 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5661
5662 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::ompx_hold))
5663 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5664
5665 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::attach))
5666 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5667
5668 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::is_device_ptr)) {
5669 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5670 if (!hasExplicitMap)
5671 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5672 }
5673
5674 return mapType;
5675}
5676
5678 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
5679 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
5680 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
5681 ArrayRef<Value> useDevAddrOperands = {},
5682 ArrayRef<Value> hasDevAddrOperands = {}) {
 // Overview: fills the parallel per-entry vectors of `mapData` (OriginalValue,
 // Pointers, BasePointers, IsDeclareTarget, BaseType, Sizes, MapClause, Types,
 // Names, DevicePointers, Mappers, IsAMapping, IsAMember) in three passes:
 //   1. one entry per value in `mapVars` (IsAMapping = true);
 //   2. `useDevAddrOperands` then `useDevPtrOperands`: each operand either
 //      tags matching existing entries with OMP_MAP_RETURN_PARAM or appends a
 //      zero-sized RETURN_PARAM entry (IsAMapping = false);
 //   3. one entry per value in `hasDevAddrOperands` (IsAMapping = false).
 // Every operand must be defined by an omp::MapInfoOp; this is enforced with
 // cast<>, which asserts otherwise.
 // NOTE(review): the line carrying this function's name and return type is
 // not present in this listing.
5683
 // True when the map type has the attach flag together with ref_ptr and/or
 // ref_ptee.
5684 auto checkRefPtrOrPteeMapWithAttach = [](omp::ClauseMapFlags mapType) {
5685 bool hasRefType =
5686 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptr) ||
5687 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptee);
5688 return hasRefType &&
5689 bitEnumContainsAll(mapType, omp::ClauseMapFlags::attach);
5690 };
5691
 // True when `mapOp` appears in the member list of any map in `mapVars`.
5692 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
5693 // Check if this is a member mapping and correctly assign that it is, if
5694 // it is a member of a larger object.
5695 // TODO: Need better handling of members, and distinguishing of members
5696 // that are implicitly allocated on device vs explicitly passed in as
5697 // arguments.
5698 // TODO: May require some further additions to support nested record
5699 // types, i.e. member maps that can have member maps.
5700 for (Value mapValue : mapVars) {
5701 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5702 for (auto member : map.getMembers())
5703 if (member == mapOp)
5704 return true;
5705 }
5706 return false;
5707 };
5708
5709 // Process MapOperands
5710 for (Value mapValue : mapVars) {
5711 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5712 bool isRefPtrOrPteeMapWithAttach =
5713 checkRefPtrOrPteeMapWithAttach(mapOp.getMapType());
 // The value to offload is varPtrPtr when it is present, except for
 // ref_ptr/ref_ptee attach maps, which use varPtr here and place varPtrPtr in
 // the Pointers field instead.
5714 Value offloadPtr = (mapOp.getVarPtrPtr() && !isRefPtrOrPteeMapWithAttach)
5715 ? mapOp.getVarPtrPtr()
5716 : mapOp.getVarPtr();
5717 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
5718 mapData.Pointers.push_back(
5719 isRefPtrOrPteeMapWithAttach
5720 ? moduleTranslation.lookupValue(mapOp.getVarPtrPtr())
5721 : mapData.OriginalValue.back());
5722
 // Declare target values are flagged as such; when a reference pointer is
 // available for them it replaces the original value as the base pointer.
5723 if (llvm::Value *refPtr =
5724 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
5725 mapData.IsDeclareTarget.push_back(true);
5726 mapData.BasePointers.push_back(refPtr);
5727 } else if (isDeclareTargetTo(offloadPtr)) {
5728 mapData.IsDeclareTarget.push_back(true);
5729 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5730 } else { // regular mapped variable
5731 mapData.IsDeclareTarget.push_back(false);
5732 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5733 }
5734
5735 // In every situation we currently have if we have a varPtrPtr present
5736 // we wish to utilise its type for the base type, main cases are
5737 // currently Fortran descriptor base address maps and attach maps.
5738 mapData.BaseType.push_back(moduleTranslation.convertType(
5739 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5740 : mapOp.getVarPtrType()));
5741
5742 // For the attach map cases, it's a little odd, as we effectively have to
5743 // utilise the base address (including all bounds offsets) for the pointer
5744 // field, the pointer address for the base address field, and the pointer
5745 // not the data (base addresses) size. So we end up with a mix of base
5746 // types and sizes we wish to insert here.
5747 mlir::Type sizeType = (isRefPtrOrPteeMapWithAttach || !mapOp.getVarPtrPtr())
5748 ? mapOp.getVarPtrType()
5749 : mapOp.getVarPtrPtrType().value();
 // Passing a null clause op for attach maps makes getSizeInBytes ignore any
 // bounds and return the plain size of `sizeType`.
5750 mapData.Sizes.push_back(getSizeInBytes(
5751 dl, sizeType, isRefPtrOrPteeMapWithAttach ? nullptr : mapOp,
5752 mapData.Pointers.back(), moduleTranslation.convertType(sizeType),
5753 builder, moduleTranslation));
5754 mapData.MapClause.push_back(mapOp.getOperation());
5755 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
5756 mapData.Names.push_back(LLVM::createMappingInformation(
5757 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5758 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5759 if (mapOp.getMapperId())
5760 mapData.Mappers.push_back(
 // NOTE(review): the callee line of this expression is missing from this
 // listing; the visible arguments are the map op and its mapper id attribute.
5762 mapOp, mapOp.getMapperIdAttr()));
5763 else
5764 mapData.Mappers.push_back(nullptr);
5765 mapData.IsAMapping.push_back(true);
5766 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5767 }
5768
 // Looks for already recorded, non-attach mapping entries whose original
 // value is `val` and whose member count equals `memberCount`. Every match is
 // tagged with OMP_MAP_RETURN_PARAM and given `devInfoTy`. Returns whether at
 // least one match was found.
5769 auto findMapInfo = [&mapData](llvm::Value *val,
5770 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
5771 size_t memberCount) {
5772 unsigned index = 0;
5773 bool found = false;
5774 for (llvm::Value *basePtr : mapData.OriginalValue) {
5775 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[index]);
5776 // TODO: Currently we define an equivalent mapping as
5777 // the same base pointer and an equivalent member count, but
5778 // that is a loose definition. We may have to extend to check
5779 // for other fields (varPtrPtr/individual members being mapped).
5780 // Note: Attach maps are not the same as a normal data transfer
5781 // they specify to the runtime to perform an attach map and they
5782 // (at least at the moment) are never something we would aim to
5783 // return in a use_dev_* clause, so they are skipped in terms of
5784 // duplicate maps.
5785 bool isAttachMap =
5786 (mapData.Types[index] &
5787 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5788 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5789 if (!isAttachMap && basePtr == val && mapData.IsAMapping[index] &&
5790 memberCount == mapOp.getMembers().size()) {
5791 found = true;
5792 mapData.Types[index] |=
5793 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5794 mapData.DevicePointers[index] = devInfoTy;
5795 }
5796 index++;
5797 }
5798 return found;
5799 };
5800
5801 // Process useDevPtr(Addr)Operands
5802 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
5803 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5804 for (Value mapValue : useDevOperands) {
5805 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5806 Value offloadPtr =
5807 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5808 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5809
5810 // Check if map info is already present for this entry.
5811 if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
 // No existing entry matched: append a zero-sized entry that only asks for
 // the device pointer/address to be returned.
5812 mapData.OriginalValue.push_back(origValue);
5813 mapData.Pointers.push_back(mapData.OriginalValue.back());
5814 mapData.IsDeclareTarget.push_back(false);
5815 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5816 mlir::Type baseTy = mapOp.getVarPtrPtr()
5817 ? mapOp.getVarPtrPtrType().value()
5818 : mapOp.getVarPtrType();
5819 mapData.BaseType.push_back(moduleTranslation.convertType(baseTy));
5820 mapData.Sizes.push_back(builder.getInt64(0));
5821 mapData.MapClause.push_back(mapOp.getOperation());
5822 mapData.Types.push_back(
5823 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5824 mapData.Names.push_back(LLVM::createMappingInformation(
5825 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5826 mapData.DevicePointers.push_back(devInfoTy);
5827 mapData.Mappers.push_back(nullptr);
5828 mapData.IsAMapping.push_back(false);
5829 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5830 }
5831 }
5832 };
5833
5834 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5835 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5836
 // Process hasDevAddrOperands: every operand gets its own entry, no
 // de-duplication against earlier entries is attempted.
5837 for (Value mapValue : hasDevAddrOperands) {
5838 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5839 Value offloadPtr =
5840 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5841 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5842 auto mapType = convertClauseMapFlags(mapOp.getMapType());
5843 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5844 bool isDevicePtr =
5845 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5846 omp::ClauseMapFlags::none;
5847
5848 mapData.OriginalValue.push_back(origValue);
5849 mapData.BasePointers.push_back(origValue);
5850 mapData.Pointers.push_back(origValue);
5851 mapData.IsDeclareTarget.push_back(false);
5852
5853 mlir::Type baseTy = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5854 : mapOp.getVarPtrType();
5855 mapData.BaseType.push_back(moduleTranslation.convertType(baseTy));
 // NOTE(review): this uses dl.getTypeSize, whereas getSizeInBytes uses
 // getTypeSizeInBits / 8 - confirm both yield bytes here.
5856 mapData.Sizes.push_back(builder.getInt64(dl.getTypeSize(baseTy)));
5857
5858 mapData.MapClause.push_back(mapOp.getOperation());
5859 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5860 // Descriptors are mapped with the ALWAYS flag, since they can get
5861 // rematerialized, so the address of the descriptor for a given object
5862 // may change from one place to another.
5863 mapData.Types.push_back(mapType);
5864 // Technically it's possible for a non-descriptor mapping to have
5865 // both has-device-addr and ALWAYS, so lookup the mapper in case it
5866 // exists.
5867 if (mapOp.getMapperId()) {
5868 mapData.Mappers.push_back(
 // NOTE(review): the callee line of this expression is missing from this
 // listing; the visible arguments are the map op and its mapper id attribute.
5870 mapOp, mapOp.getMapperIdAttr()));
5871 } else {
5872 mapData.Mappers.push_back(nullptr);
5873 }
5874 } else {
5875 // For is_device_ptr we need the map type to propagate so the runtime
5876 // can materialize the device-side copy of the pointer container.
5877 mapData.Types.push_back(
5878 isDevicePtr ? mapType
5879 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5880 mapData.Mappers.push_back(nullptr);
5881 }
5882 mapData.Names.push_back(LLVM::createMappingInformation(
5883 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5884 mapData.DevicePointers.push_back(
5885 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5886 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5887 mapData.IsAMapping.push_back(false);
5888 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5889 }
5890}
5891
5892static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
5893 auto *res = llvm::find(mapData.MapClause, memberOp);
5894 assert(res != mapData.MapClause.end() &&
5895 "MapInfoOp for member not found in MapData, cannot return index");
5896 return std::distance(mapData.MapClause.begin(), res);
5897}
5898
5900 omp::MapInfoOp mapInfo, bool first = true) {
 // Sorts `indices` (positions into mapInfo's members-index attribute) by
 // comparing the members' index paths element by element. With `first` set,
 // the smaller path sorts first; otherwise the order is reversed. Members
 // whose path merely extends another compared member's path are recorded as
 // occluded and erased from `indices` once sorting is done.
 // NOTE(review): the line carrying this function's name and its first
 // parameter (`indices`) is not present in this listing; the name used here
 // is taken from the call in getFirstOrLastMappedMemberPtr.
5901 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5902 llvm::SmallVector<size_t> occludedChildren;
5903 llvm::sort(
5904 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
5905 // Bail early if we are asked to look at the same index. If we do not
5906 // bail early, we can end up mistakenly adding indices to
5907 // occludedChildren. This can occur with some types of libc++ hardening.
5908 if (a == b)
5909 return false;
5910
5911 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5912 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5913
 // Walk both paths in lock-step; zip stops at the end of the shorter one.
5914 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5915 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5916 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5917
5918 if (aIndex == bIndex)
5919 continue;
5920
5921 if (aIndex < bIndex)
5922 return first;
5923
5924 if (aIndex > bIndex)
5925 return !first;
5926 }
5927
5928 // Iterated up until the end of the smallest member and
5929 // they were found to be equal up to that point, so select
5930 // the member with the lowest index count, so the "parent"
 // NOTE(review): this result does not depend on `first`, and the comparator
 // mutates occludedChildren as a side effect of being called.
5931 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5932 if (memberAParent)
5933 occludedChildren.push_back(b);
5934 else
5935 occludedChildren.push_back(a);
5936 return memberAParent;
5937 });
5938
 // Drop every index that was found to be occluded during the sort.
5939 for (auto v : occludedChildren)
5940 indices.erase(std::remove(indices.begin(), indices.end(), v),
5941 indices.end());
5942}
5943
5944static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
5945 bool first) {
5946 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5947 // Only 1 member has been mapped, we can return it.
5948 if (indexAttr.size() == 1)
5949 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5950 llvm::SmallVector<size_t> indices(indexAttr.size());
5951 std::iota(indices.begin(), indices.end(), 0);
5952 sortMapIndices(indices, mapInfo, first);
5953 return llvm::cast<omp::MapInfoOp>(
5954 mapInfo.getMembers()[indices.front()].getDefiningOp());
5955}
5956
5957/// This function calculates the array/pointer offset for map data provided
5958/// with bounds operations, e.g. when provided something like the following:
5959///
5960/// Fortran
5961/// map(tofrom: array(2:5, 3:2))
5962///
5963/// We must calculate the initial pointer offset to pass across, this function
5964/// performs this using bounds.
5965///
/// Returns the list of GEP indices: empty when there are no bounds, a leading
/// zero followed by one lower bound per dimension (in reverse bounds order)
/// when `isArrayTy` is set, and a single linearised index otherwise.
///
5966/// TODO/WARNING: This only supports Fortran's column major indexing currently
5967/// as is noted in the note below and comments in the function, we must extend
5968/// this function when we add a C++ frontend.
5969/// NOTE: which while specified in row-major order it currently needs to be
5970/// flipped for Fortran's column order array allocation and access (as
5971/// opposed to C++'s row-major, hence the backwards processing where order is
5972/// important). This is likely important to keep in mind for the future when
5973/// we incorporate a C++ frontend, both frontends will need to agree on the
5974/// ordering of generated bounds operations (one may have to flip them) to
5975/// make the below lowering frontend agnostic. The offload size
5976/// calculation may also have to be adjusted for C++.
// NOTE(review): the line carrying this function's name and its first
// parameter (`moduleTranslation`) is not present in this listing.
5977static std::vector<llvm::Value *>
5979 llvm::IRBuilderBase &builder, bool isArrayTy,
5980 OperandRange bounds) {
5981 std::vector<llvm::Value *> idx;
5982 // There's no bounds to calculate an offset from, we can safely
5983 // ignore and return no indices.
5984 if (bounds.empty())
5985 return idx;
5986
5987 // If we have an array type, then we have its type so can treat it as a
5988 // normal GEP instruction where the bounds operations are simply indexes
5989 // into the array. We currently do reverse order of the bounds, which
5990 // I believe leans more towards Fortran's column-major in memory.
5991 if (isArrayTy) {
5992 idx.push_back(builder.getInt64(0));
5993 for (int i = bounds.size() - 1; i >= 0; --i) {
5994 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5995 bounds[i].getDefiningOp())) {
5996 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
5997 }
5998 }
5999 } else {
6000 // If we do not have an array type, but we have bounds, then we're dealing
6001 // with a pointer that's being treated like an array and we have the
6002 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
6003 // address (pointer pointing to the actual data) so we must calculate the
6004 // offset using a single index which the following loop attempts to
6005 // compute using the standard column-major algorithm e.g for a 3D array:
6006 //
6007 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
6008 //
6009 // It is of note that it's doing column-major rather than row-major at the
6010 // moment, but having a way for the frontend to indicate which major format
6011 // to use or standardizing/canonicalizing the order of the bounds to compute
6012 // the offset may be useful in the future when there's other frontends with
6013 // different formats.
 // NOTE(review): dimensionIndexSizeOffset is declared but never used in the
 // visible code.
6014 std::vector<llvm::Value *> dimensionIndexSizeOffset;
6015 for (int i = bounds.size() - 1; i >= 0; --i) {
6016 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6017 bounds[i].getDefiningOp())) {
 // The last bound seeds the running index; every earlier bound scales it by
 // its extent and adds its own lower bound.
6018 if (i == ((int)bounds.size() - 1))
6019 idx.emplace_back(
6020 moduleTranslation.lookupValue(boundOp.getLowerBound()));
6021 else
6022 idx.back() = builder.CreateAdd(
6023 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
6024 boundOp.getExtent())),
6025 moduleTranslation.lookupValue(boundOp.getLowerBound()));
6026 }
6027 }
6028 }
6029
6030 return idx;
6031}
6032
 // Appends the integer value of every attribute in `values` to `ints`; each
 // element must be an IntegerAttr (cast<> asserts otherwise).
 // NOTE(review): the signature line is not present in this listing; the name
 // getAsIntegers is taken from the call site that passes an ArrayAttr and a
 // SmallVector<int64_t>.
6034 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
6035 return cast<IntegerAttr>(value).getInt();
6036 });
6037}
6038
6039// Gathers members that are overlapping in the parent, excluding members that
6040// themselves overlap, keeping the top-most (closest to parents level) map.
// The result is appended to `overlapMapDataIdxs` as positions into
// parentOp's member list, in ascending order. Nothing is appended when the
// parent has no members.
// NOTE(review): the line carrying this function's name and its first
// parameter (`overlapMapDataIdxs`) is not present in this listing.
6041static void
6043 omp::MapInfoOp parentOp) {
6044 // No members mapped, no overlaps.
6045 if (parentOp.getMembers().empty())
6046 return;
6047
6048 // Single member, we can insert and return early.
6049 if (parentOp.getMembers().size() == 1) {
6050 overlapMapDataIdxs.push_back(0);
6051 return;
6052 }
6053
6054 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
6055 size_t numMembers = indexAttr.size();
6056
6057 // Pre-convert all member indices to integer arrays for efficient comparison.
6058 llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices(numMembers);
6059 for (auto [i, indicesAttr] : llvm::enumerate(indexAttr))
6060 getAsIntegers(cast<ArrayAttr>(indicesAttr), memberIndices[i]);
6061
6062 // For each member, check if it's superseded by another (shorter prefix)
6063 // member. If member j's indices are a prefix of member i's indices, then
6064 // i is a child of j and should be skipped. e.g. if member [0] is mapped,
6065 // we skip members [0,1], [0,2], etc.
 // This is a pairwise comparison of all members.
6066 llvm::SmallDenseSet<size_t> skipIndices;
6067 for (size_t i = 0; i < numMembers; ++i) {
6068 const auto &iIndices = memberIndices[i];
6069 for (size_t j = 0; j < numMembers; ++j) {
6070 if (i == j)
6071 continue;
6072 const auto &jIndices = memberIndices[j];
6073 // If j's indices are a strict prefix of i's indices, skip i
6074 if (jIndices.size() < iIndices.size() &&
6075 std::equal(jIndices.begin(), jIndices.end(), iIndices.begin())) {
6076 skipIndices.insert(i);
6077 break; // No need to check other potential parents
6078 }
6079 }
6080 }
6081
6082 // Collect indices of members that are not superseded by a parent.
6083 for (size_t i = 0; i < numMembers; ++i)
6084 if (!skipIndices.contains(i))
6085 overlapMapDataIdxs.push_back(i);
6086}
6087
6088// The intent is to verify if the mapped data being passed is a
6089// pointer -> pointee that requires special handling in certain cases,
6090// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
6091//
6092// There may be a better way to verify this, but unfortunately with
6093// opaque pointers we lose the ability to easily check if something is
6094// a pointer whilst maintaining access to the underlying type.
6095static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
6096 // If we have a varPtrPtr field assigned then the underlying type is a pointer
6097 if (mapOp.getVarPtrPtr())
6098 return true;
6099
6100 // If the map data is declare target with a link clause, then it's represented
6101 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
6102 // no relation to pointers.
6103 if (isDeclareTargetLink(mapOp.getVarPtr()))
6104 return true;
6105
6106 return false;
6107}
6108
6109/// This function handles the insertion of a single item of map data from
6110/// MapInfoData into the OMPIRBuilder's MapInfo list. Utilising this function
6111/// means the map being inserted can be treated as a non-parent map entity,
6112/// if the memberOfFlag is set then the map being inserted is treated as
6113/// a member map of a larger entity. The insertion into the MapInfo list of
6114/// the OMPIRBuilder can vary based on a number of factors, such as if it's
6115/// a ref_ptr or ref_ptee map, if it's a member of a record, what construct
6116/// the map belongs to and the various map type bit flags that are set for
6117/// the map.
6118static void
6119processIndividualMap(llvm::IRBuilderBase &builder,
6120 llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData,
6121 size_t mapDataIdx, MapInfosTy &combinedInfo,
6122 TargetDirectiveEnumTy targetDirective,
6123 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6124 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6125 bool isTargetParam = true, int mapDataParentIdx = -1) {
6126 auto mapFlag = mapData.Types[mapDataIdx];
6127 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
6128
6129 bool isPtrTy = checkIfPointerMap(mapInfoOp);
6130 bool isAttachMap = ((convertClauseMapFlags(mapInfoOp.getMapType()) &
6131 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6132 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
6133
6134 // Declare target variables are not passed to the kernel, and for the moment
6135 // attach maps are not passed to the kernel. However, it is possible to create
6136 // attach maps that transfer data and thus can be kernel arguments, but our
6137 // existing frontend does not do this.
6138 if (isTargetParam &&
6139 (targetDirective == TargetDirectiveEnumTy::Target &&
6140 !mapData.IsDeclareTarget[mapDataIdx]) &&
6141 !isAttachMap)
6142 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
6143
6144 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
6145 !isPtrTy)
6146 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
6147
6148 // If we have a pointer and it's part of a MEMBER_OF mapping we do not apply
6149 // MEMBER_OF, as the runtime currently has a work-around that utilises
6150 // MEMBER_OF to prevent reference updating in certain scenarios instead of
6151 // target_param. However, this causes a noticeable issue in cases where we
6152 // map some data (Fortran descriptor primarily at the moment), alter it on
6153 // the host, and then expect it to not be updated in a subsequent implicit map
6154 // (such as an implicit map on a target).
6155 if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) {
6156 if (!isPtrTy && !isAttachMap)
6157 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
6158
6159 // The return parameter should be the over-riding parent in cases where we
6160 // have a return parameter that is echoed to all members, the main case of
6161 // this currently is with fortran descriptors. It may need more finessing
6162 // for C/C++ in the future or descriptors that are members of derived
6163 // types.
6164 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
6165 }
6166
6167 // We apply MAP_PTR_AND_OBJ when within a declare mapper object as it enforces
6168 // MEMBER_OF mappings on maps that are passed the initial nesting depth, which
6169 // includes pointed to data and attach members, both of which are technically
6170 // not part of the main object. This has the side effect of causing early
6171 // map-backs in certain cases where an implicit declare mapper has been
6172 // emitted for a target region. Applying MAP_PTR_AND_OBJ in these situations
6173 // circumvents this.
6174 if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx])
6175 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
6176
6177 // if we're provided a mapDataParentIdx, then the data being mapped is
6178 // part of a larger object (in a parent <-> member mapping) and in this
6179 // case our BasePointer should be the parent. Except in the edge case
6180 // where we are mapping pointee data, where we try staying close to
6181 // what Clang currently does and utilise the regular base pointer of the
6182 // data.
6183 bool isRefPtee =
6184 !bitEnumContainsAll(mapInfoOp.getMapType(),
6185 omp::ClauseMapFlags::ref_ptr) &&
6186 bitEnumContainsAll(mapInfoOp.getMapType(), omp::ClauseMapFlags::ref_ptee);
6187 bool isRefPtrPtee = bitEnumContainsAll(mapInfoOp.getMapType(),
6188 omp::ClauseMapFlags::ref_ptr |
6189 omp::ClauseMapFlags::ref_ptee);
6190
6191 if (!mapInfoOp->getParentOfType<omp::DeclareMapperOp>() &&
6192 mapDataParentIdx >= 0 && !(isRefPtee || (isRefPtrPtee && isPtrTy))) {
6193 combinedInfo.BasePointers.emplace_back(
6194 mapData.BasePointers[mapDataParentIdx]);
6195 } else {
6196 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
6197 }
6198
6199 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
6200 combinedInfo.DevicePointers.emplace_back(
6201 memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE
6202 ? llvm::OpenMPIRBuilder::DeviceInfoTy::None
6203 : mapData.DevicePointers[mapDataIdx]);
6204 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
6205 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
6206 combinedInfo.Types.emplace_back(mapFlag);
6207 combinedInfo.Sizes.emplace_back(
6208 isPtrTy ? builder.CreateSelect(
6209 builder.CreateIsNull(mapData.Pointers[mapDataIdx]),
6210 builder.getInt64(0), mapData.Sizes[mapDataIdx])
6211 : mapData.Sizes[mapDataIdx]);
6212}
6213
6214// This creates two insertions into the MapInfosTy data structure for the
6215// "parent" of a set of members, (usually a container e.g.
6216// class/structure/derived type) when subsequent members have also been
6217// explicitly mapped on the same map clause. Certain types, such as Fortran
6218// descriptors are mapped like this as well, however, the members are
6219// implicit as far as a user is concerned, but we must explicitly map them
6220// internally.
6221//
6222// This function also returns the memberOfFlag for this particular parent,
6223// which is utilised in subsequent member mappings (by modifying their map type
6224// with it) to indicate that a member is part of this parent and should be
6225// treated by the runtime as such. Important to achieve the correct mapping.
6226//
6227// This function borrows a lot from Clang's emitCombinedEntry function
6228// inside of CGOpenMPRuntime.cpp
6230 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
6231 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
6232 MapInfoData &mapData, uint64_t mapDataIndex,
6233 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
6234 TargetDirectiveEnumTy targetDirective) {
6235 using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
6236 assert(!ompBuilder.Config.isTargetDevice() &&
6237 "function only supported for host device codegen");
6238 auto parentClause =
6239 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6240 auto *parentMapper = mapData.Mappers[mapDataIndex];
6241
6242 // Map the first segment of the parent. If a user-defined mapper is attached,
6243 // include the parent's to/from-style bits (and common modifiers) in this
6244 // base entry so the mapper receives correct copy semantics via its 'type'
6245 // parameter. Also keep TARGET_PARAM when required for kernel arguments.
6246 MapFlags baseFlag = (targetDirective == TargetDirectiveEnumTy::Target &&
6247 !mapData.IsDeclareTarget[mapDataIndex])
6248 ? MapFlags::OMP_MAP_TARGET_PARAM
6249 : MapFlags::OMP_MAP_NONE;
6250
6251 if (parentMapper) {
6252 // Preserve relevant map-type bits from the parent clause. These include
6253 // the copy direction (TO/FROM), as well as commonly used modifiers that
6254 // should be visible to the mapper for correct behaviour.
6255 MapFlags parentFlags = mapData.Types[mapDataIndex];
6256 MapFlags preserve = MapFlags::OMP_MAP_TO | MapFlags::OMP_MAP_FROM |
6257 MapFlags::OMP_MAP_ALWAYS | MapFlags::OMP_MAP_CLOSE |
6258 MapFlags::OMP_MAP_PRESENT |
6259 MapFlags::OMP_MAP_OMPX_HOLD |
6260 MapFlags::OMP_MAP_IMPLICIT;
6261 baseFlag |= (parentFlags & preserve);
6262 } else {
6263 MapFlags parentFlags = mapData.Types[mapDataIndex];
6264 MapFlags preserve =
6265 MapFlags::OMP_MAP_PRESENT | MapFlags::OMP_MAP_RETURN_PARAM;
6266 baseFlag |= (parentFlags & preserve);
6267 }
6268
6269 combinedInfo.Types.emplace_back(baseFlag);
6270 combinedInfo.DevicePointers.emplace_back(
6271 mapData.DevicePointers[mapDataIndex]);
6272 // Only attach the mapper to the base entry when we are mapping the whole
6273 // parent. Combined/segment entries must not carry a mapper; otherwise the
6274 // mapper can be invoked with a partial size, which is undefined behaviour.
6275 combinedInfo.Mappers.emplace_back(
6276 parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
6277 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6278 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6279 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
6280
6281 // Calculate size of the parent object being mapped based on the
6282 // addresses at runtime, highAddr - lowAddr = size. This of course
6283 // doesn't factor in allocated data like pointers, hence the further
6284 // processing of members specified by users, or in the case of
6285 // Fortran pointers and allocatables, the mapping of the pointed to
6286 // data by the descriptor (which itself, is a structure containing
6287 // runtime information on the dynamically allocated data).
6288 llvm::Value *lowAddr, *highAddr;
6289 if (!parentClause.getPartialMap()) {
6290 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6291 builder.getPtrTy());
6292 highAddr = builder.CreatePointerCast(
6293 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6294 mapData.Pointers[mapDataIndex], 1),
6295 builder.getPtrTy());
6296 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6297 } else {
6298 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6299 int firstMemberIdx = getMapDataMemberIdx(
6300 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
6301 lowAddr = builder.CreatePointerCast(mapData.BasePointers[firstMemberIdx],
6302 builder.getPtrTy());
6303
6304 int lastMemberIdx = getMapDataMemberIdx(
6305 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
6306 auto lastMemberMapInfo =
6307 cast<omp::MapInfoOp>(mapData.MapClause[lastMemberIdx]);
6308
6309 // NOTE: Currently, for RefPtee the BaseType is set to the varPtrPtr field,
6310 // which is the pointer datas type and not the member within the structure
6311 // that it's part of, so we have to make sure we use the member type in this
6312 // case when calculating the parents size offsets.
6313 // TODO: May be good to extend MapInfoData to support tracking of both
6314 // VarPtr/VarPtrPtr BaseType's to better distinguish what's being used more
6315 // consistently.
6316 bool isRefPteeMap = bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6317 omp::ClauseMapFlags::ref_ptee) &&
6318 !bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6319 omp::ClauseMapFlags::ref_ptr);
6320 llvm::Type *castType = mapData.BaseType[lastMemberIdx];
6321 if (isRefPteeMap)
6322 castType =
6323 moduleTranslation.convertType(lastMemberMapInfo.getVarPtrType());
6324 highAddr = builder.CreatePointerCast(
6325 builder.CreateGEP(castType, mapData.BasePointers[lastMemberIdx],
6326 builder.getInt64(1)),
6327 builder.getPtrTy());
6328 combinedInfo.Pointers.emplace_back(mapData.BasePointers[firstMemberIdx]);
6329 }
6330
6331 llvm::Value *size = builder.CreateIntCast(
6332 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6333 builder.getInt64Ty(),
6334 /*isSigned=*/false);
6335 combinedInfo.Sizes.push_back(size);
6336
6337 // This creates the initial MEMBER_OF mapping that consists of
6338 // the parent/top level container (same as above effectively, except
6339 // with a fixed initial compile time size and separate maptype which
6340 // indicates the true map type (tofrom etc.). This parent mapping is
6341 // only relevant if the structure in its totality is being mapped,
6342 // otherwise the above suffices.
6343 if (!parentClause.getPartialMap()) {
6344 // TODO: This will need to be expanded to include the whole host of logic
6345 // for the map flags that Clang currently supports (e.g. it should do some
6346 // further case specific flag modifications). For the moment, it handles
6347 // what we support as expected.
6348 MapFlags mapFlag = mapData.Types[mapDataIndex];
6349 bool hasMapClose = (MapFlags(mapFlag) & MapFlags::OMP_MAP_CLOSE) ==
6350 MapFlags::OMP_MAP_CLOSE;
6351 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
6352
6353 llvm::SmallVector<size_t> overlapIdxs;
6354 // Find all of the members that "overlap", i.e. occlude other members that
6355 // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
6356 // [0,2], but not [1,0].
6357 getOverlappedMembers(overlapIdxs, parentClause);
6358
6359 // When we only have one overlap we skip the case that tries to segment the
6360 // mapping as best it can without creating holes, as the calculation is more
6361 // likely to have more overhead than anything we gain from mapping a smaller
6362 // chunk of data. This can be seen in cases where we are mapping Fortran
6363 // descriptors which are a special case of record type mapping.
6364 //
6365 // The cases for close and update are unique edge cases where the segmenting
6366 // does not play well with the runtime currently.
6367 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose ||
6368 overlapIdxs.size() == 1) {
6369 combinedInfo.Types.emplace_back(mapFlag);
6370 combinedInfo.DevicePointers.emplace_back(
6371 mapData.DevicePointers[mapDataIndex]);
6372 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6373 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6374 combinedInfo.BasePointers.emplace_back(
6375 mapData.BasePointers[mapDataIndex]);
6376 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6377 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
6378 combinedInfo.Mappers.emplace_back(nullptr);
6379 } else {
6380 // We need to make sure the overlapped members are sorted in order of
6381 // lowest address to highest address.
6382 sortMapIndices(overlapIdxs, parentClause);
6383
6384 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6385 builder.getPtrTy());
6386 highAddr = builder.CreatePointerCast(
6387 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6388 mapData.Pointers[mapDataIndex], 1),
6389 builder.getPtrTy());
6390
6391 // Currently, the return parameter should be the over-riding parent in
6392 // cases where we have a return parameter that is echoed to all members,
6393 // the main case of this currently is with fortran descriptors. It may
6394 // need more finessing for C/C++ in the future or descriptors that are
6395 // members of derived types.
6396 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
6397
6398 // TODO: We may want to skip arrays/array sections in this as Clang does.
6399 // It appears to be an optimisation rather than a necessity though,
6400 // but this requires further investigation. However, we would have to make
6401 // sure to not exclude maps with bounds that ARE pointers, as these are
6402 // processed as separate components, i.e. pointer + data.
6403 for (auto v : overlapIdxs) {
6404 auto mapDataOverlapIdx = getMapDataMemberIdx(
6405 mapData,
6406 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
6407 auto isPtrMap = checkIfPointerMap(
6408 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataOverlapIdx]));
6409 combinedInfo.Types.emplace_back(mapFlag);
6410 combinedInfo.DevicePointers.emplace_back(
6411 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6412 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6413 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6414 combinedInfo.BasePointers.emplace_back(
6415 mapData.BasePointers[mapDataIndex]);
6416 combinedInfo.Mappers.emplace_back(nullptr);
6417 combinedInfo.Pointers.emplace_back(lowAddr);
6418 auto sizeCalc = builder.CreateIntCast(
6419 builder.CreatePtrDiff(builder.getInt8Ty(),
6420 mapData.OriginalValue[mapDataOverlapIdx],
6421 lowAddr),
6422 builder.getInt64Ty(), /*isSigned=*/true);
6423 // In certain cases, we'll generate a size of 0 if we're not careful
6424 // (e.g. if lowAddr happens to be the first member), which isn't
6425 // correct, even if the runtime is sometimes fine with it, so in these
6426 // scenarios we select the types size instead.
6427 auto sizeSel = builder.CreateSelect(
6428 builder.CreateICmpNE(builder.getInt64(0), sizeCalc), sizeCalc,
6429 isPtrMap ? llvm::ConstantExpr::getSizeOf(builder.getPtrTy())
6430 : mapData.Sizes[mapDataOverlapIdx]);
6431 combinedInfo.Sizes.emplace_back(sizeSel);
6432 lowAddr = builder.CreateConstGEP1_32(
6433 isPtrMap ? builder.getPtrTy() : mapData.BaseType[mapDataOverlapIdx],
6434 mapData.BasePointers[mapDataOverlapIdx], 1);
6435 }
6436
6437 combinedInfo.Types.emplace_back(mapFlag);
6438 combinedInfo.DevicePointers.emplace_back(
6439 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6440 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6441 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6442 combinedInfo.BasePointers.emplace_back(
6443 mapData.BasePointers[mapDataIndex]);
6444 combinedInfo.Mappers.emplace_back(nullptr);
6445 combinedInfo.Pointers.emplace_back(lowAddr);
6446 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
6447 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6448 builder.getInt64Ty(), true));
6449 }
6450 }
6451}
6452
// NOTE(review): the line that carries this function's return type, name and
// first parameter (listing line 6453) is absent from this listing. Judging by
// the call in genMapInfos, which passes (moduleTranslation, builder,
// *ompBuilder, dl, combinedInfo, mapData, i, targetDirective), this is
// presumably processMapWithMembersOf - confirm against the upstream file.
//
// Emits the combinedInfo entries for the map clause at mapDataIndex when that
// clause has members:
//  * a partial map with exactly one member is emitted as a stand-alone member
//    map and the function returns early;
//  * otherwise the parent is emitted through mapParentWithMembers and every
//    member through processIndividualMap, all sharing one MEMBER_OF flag.
6454 llvm::IRBuilderBase &builder,
6455 llvm::OpenMPIRBuilder &ompBuilder,
6456 DataLayout &dl, MapInfosTy &combinedInfo,
6457 MapInfoData &mapData, uint64_t mapDataIndex,
6458 TargetDirectiveEnumTy targetDirective) {
6459 assert(!ompBuilder.Config.isTargetDevice() &&
6460 "function only supported for host device codegen");
6461
6462 auto parentClause =
6463 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6464
6465 // If we have a partial map (no parent referenced in the map clauses of the
6466 // directive, only members) and only a single member, we do not need to bind
6467 // the map of the member to the parent, we can pass the member separately.
6468 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6469 auto memberClause = llvm::cast<omp::MapInfoOp>(
6470 parentClause.getMembers()[0].getDefiningOp());
6471 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
6472 // Note: Clang treats arrays with explicit bounds that fall into this
6473 // category as a parent with map case, however, it seems this isn't a
6474 // requirement, and processing them as an individual map is fine. So,
6475 // we will handle them as individual maps for the moment, as it's
6476 // difficult for us to check this as we always require bounds to be
6477 // specified currently and it's also marginally more optimal (single
6478 // map rather than two). The difference may come from the fact that
6479 // Clang maps array without bounds as pointers (which we do not
6480 // currently do), whereas we treat them as arrays in all cases
6481 // currently.
 // NOTE(review): the callee line (listing line 6482) is absent here. The
 // argument list below has the same shape as the processIndividualMap call
 // in the loop further down, so it is presumably that function - confirm.
6483 builder, ompBuilder, mapData, memberDataIdx, combinedInfo,
6484 targetDirective,
6485 /*MemberOfFlag=*/llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6486 /*isTargetParam=*/true, mapDataIndex);
6487 return;
6488 }
6489
 // Collects MapInfoData indices: the parent first, then each member in the
 // order getMembers() returns them. The inner parentClause shadows the outer
 // one; both are obtained from mapData.MapClause[mapDataIndex].
6490 auto collectMapInfoIdxs =
6491 [&](llvm::SmallVectorImpl<int64_t> &mapsAndInfoIdx) {
6492 auto parentClause =
6493 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6494 mapsAndInfoIdx.push_back(getMapDataMemberIdx(mapData, parentClause));
6495 for (auto member : parentClause.getMembers())
6496 mapsAndInfoIdx.push_back(getMapDataMemberIdx(
6497 mapData, llvm::cast<omp::MapInfoOp>(member.getDefiningOp())));
6498 };
6499
6500 llvm::SmallVector<int64_t> mapInfoIdx;
6501 collectMapInfoIdxs(mapInfoIdx);
6502
 // The flag is derived from the number of entries already in combinedInfo,
 // i.e. before the parent entry for this clause is appended.
6503 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6504 ompBuilder.getMemberOfFlag(combinedInfo.Types.size());
6505 for (size_t i = 0; i < mapInfoIdx.size(); i++) {
6506 // Index == 0 is the parent map and if it gets here it's an unattachable
6507 // type and should have OMP_MAP_TARGET_PARAM applied and no MEMBER_OF flag.
6508 if (i == 0) {
6509 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
6510 combinedInfo, mapData, mapInfoIdx[i], memberOfFlag,
6511 targetDirective);
6512 } else {
6513 processIndividualMap(builder, ompBuilder, mapData, mapInfoIdx[i],
6514 combinedInfo, targetDirective, memberOfFlag,
6515 /*isTargetParam=*/false, mapDataIndex);
6516 }
6517 }
6518}
6519
6520// This is a variation on Clang's GenerateOpenMPCapturedVars, which
6521// generates different operation (e.g. load/store) combinations for
6522// arguments to the kernel, based on map capture kinds which are then
6523// utilised in the combinedInfo in place of the original Map value.
6524static void
6525createAlteredByCaptureMap(MapInfoData &mapData,
6526 LLVM::ModuleTranslation &moduleTranslation,
6527 llvm::IRBuilderBase &builder) {
6528 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6529 "function only supported for host device codegen");
6530 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6531 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6532 bool isAttachMap =
6533 ((convertClauseMapFlags(mapOp.getMapType()) &
6534 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6535 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
6536
6537 // If it's declare target, skip it, it's handled separately. However, if
6538 // it's declare target, and an attach map, we want to calculate the exact
6539 // address offset so that we attach correctly.
6540 if (!mapData.IsDeclareTarget[i] ||
6541 (mapData.IsDeclareTarget[i] && isAttachMap)) {
6542 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6543 bool isPtrTy = checkIfPointerMap(mapOp);
6544
6545 // Currently handles array sectioning lowerbound case, but more
6546 // logic may be required in the future. Clang invokes EmitLValue,
6547 // which has specialised logic for special Clang types such as user
6548 // defines, so it is possible we will have to extend this for
6549 // structures or other complex types. As the general idea is that this
6550 // function mimics some of the logic from Clang that we require for
6551 // kernel argument passing from host -> device.
6552 switch (captureKind) {
6553 case omp::VariableCaptureKind::ByRef: {
6554 llvm::Value *newV = mapData.Pointers[i];
6555 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
6556 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
6557 mapOp.getBounds());
6558 if (isPtrTy)
6559 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6560
6561 if (!offsetIdx.empty())
6562 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6563 "array_offset");
6564 mapData.Pointers[i] = newV;
6565 } break;
6566 case omp::VariableCaptureKind::ByCopy: {
6567 llvm::Type *type = mapData.BaseType[i];
6568 llvm::Value *newV;
6569 if (mapData.Pointers[i]->getType()->isPointerTy())
6570 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6571 else
6572 newV = mapData.Pointers[i];
6573
6574 if (!isPtrTy) {
6575 auto curInsert = builder.saveIP();
6576 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6577 builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
6578 auto *memTempAlloc =
6579 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
6580 builder.SetCurrentDebugLocation(DbgLoc);
6581 builder.restoreIP(curInsert);
6582
6583 builder.CreateStore(newV, memTempAlloc);
6584 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
6585 }
6586
6587 mapData.Pointers[i] = newV;
6588 mapData.BasePointers[i] = newV;
6589 } break;
6590 case omp::VariableCaptureKind::This:
6591 case omp::VariableCaptureKind::VLAType:
6592 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
6593 break;
6594 }
6595 }
6596 }
6597}
6598
6599// Generate all map related information and fill the combinedInfo.
6600static void genMapInfos(llvm::IRBuilderBase &builder,
6601 LLVM::ModuleTranslation &moduleTranslation,
6602 DataLayout &dl, MapInfosTy &combinedInfo,
6603 MapInfoData &mapData,
6604 TargetDirectiveEnumTy targetDirective) {
6605 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6606 "function only supported for host device codegen");
6607 // We wish to modify some of the methods in which arguments are
6608 // passed based on their capture type by the target region, this can
6609 // involve generating new loads and stores, which changes the
6610 // MLIR value to LLVM value mapping, however, we only wish to do this
6611 // locally for the current function/target and also avoid altering
6612 // ModuleTranslation, so we remap the base pointer or pointer stored
6613 // in the map infos corresponding MapInfoData, which is later accessed
6614 // by genMapInfos and createTarget to help generate the kernel and
6615 // kernel arg structure. It primarily becomes relevant in cases like
6616 // bycopy, or byref range'd arrays. In the default case, we simply
6617 // pass thee pointer byref as both basePointer and pointer.
6618 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
6619
6620 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6621
6622 // We operate under the assumption that all vectors that are
6623 // required in MapInfoData are of equal lengths (either filled with
6624 // default constructed data or appropiate information) so we can
6625 // utilise the size from any component of MapInfoData, if we can't
6626 // something is missing from the initial MapInfoData construction.
6627 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6628 if (mapData.IsAMember[i])
6629 continue;
6630
6631 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
6632 if (!mapInfoOp.getMembers().empty()) {
6633 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
6634 combinedInfo, mapData, i, targetDirective);
6635 continue;
6636 }
6637
6638 processIndividualMap(builder, *ompBuilder, mapData, i, combinedInfo,
6639 targetDirective);
6640 }
6641}
6642
// Forward declaration. getOrCreateUserDefinedMapperFunc (below) calls
// emitUserDefinedMapper, and the per-entry mapper callback inside
// emitUserDefinedMapper calls getOrCreateUserDefinedMapperFunc again, so one
// of the two has to be declared before it is defined.
6643static llvm::Expected<llvm::Function *>
6644emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
6645 LLVM::ModuleTranslation &moduleTranslation,
6646 llvm::StringRef mapperFuncName,
6647 TargetDirectiveEnumTy targetDirective);
6648
6649static llvm::Expected<llvm::Function *>
6650getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
6651 LLVM::ModuleTranslation &moduleTranslation,
6652 TargetDirectiveEnumTy targetDirective) {
6653 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6654 "function only supported for host device codegen");
6655 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6656 std::string mapperFuncName =
6657 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
6658 {"omp_mapper", declMapperOp.getSymName()});
6659
6660 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
6661 return lookupFunc;
6662
6663 // Recursive types can cause re-entrant mapper emission. The mapper function
6664 // is created by OpenMPIRBuilder before the callbacks run, so it may already
6665 // exist in the LLVM module even though it is not yet registered in the
6666 // ModuleTranslation mapping table. Reuse and register it to break the
6667 // recursion.
6668 if (llvm::Function *existingFunc =
6669 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
6670 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
6671 return existingFunc;
6672 }
6673
6674 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
6675 mapperFuncName, targetDirective);
6676}
6677
6678static llvm::Expected<llvm::Function *>
6679emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
6680 LLVM::ModuleTranslation &moduleTranslation,
6681 llvm::StringRef mapperFuncName,
6682 TargetDirectiveEnumTy targetDirective) {
6683 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6684 "function only supported for host device codegen");
6685 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6686 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
6687 DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
6688 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6689 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
6690 SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
6691
6692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6693
6694 // Fill up the arrays with all the mapped variables.
6695 MapInfosTy combinedInfo;
6696 auto genMapInfoCB =
6697 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6698 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6699 builder.restoreIP(codeGenIP);
6700 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
6701 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
6702 builder.GetInsertBlock());
6703 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
6704 /*ignoreArguments=*/true,
6705 builder)))
6706 return llvm::make_error<PreviouslyReportedError>();
6707 MapInfoData mapData;
6708 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6709 builder);
6710 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6711 targetDirective);
6712
6713 // Drop the mapping that is no longer necessary so that the same region
6714 // can be processed multiple times.
6715 moduleTranslation.forgetMapping(declMapperOp.getRegion());
6716 return combinedInfo;
6717 };
6718
6719 auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
6720 if (!combinedInfo.Mappers[i])
6721 return nullptr;
6722 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6723 moduleTranslation, targetDirective);
6724 };
6725
6726 llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
6727 genMapInfoCB, varType, mapperFuncName, customMapperCB,
6728 /*PreserveMemberOfFlags=*/true);
6729 if (!newFn)
6730 return newFn.takeError();
6731 if ([[maybe_unused]] llvm::Function *mappedFunc =
6732 moduleTranslation.lookupFunction(mapperFuncName)) {
6733 assert(mappedFunc == *newFn &&
6734 "mapper function mapping disagrees with emitted function");
6735 } else {
6736 moduleTranslation.mapFunction(mapperFuncName, *newFn);
6737 }
6738 return *newFn;
6739}
6740
// Translates omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update by gathering their clause operands and handing them to
// OpenMPIRBuilder::createTargetData. Only omp.target_data has a region and
// therefore a body-generation callback; the other three pass the runtime
// entry point selected in RTLFn instead. Returns failure if the op is
// rejected by checkImplementationStatus or if createTargetData reports an
// error.
6741static LogicalResult
6742convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
6743 LLVM::ModuleTranslation &moduleTranslation) {
6744 llvm::Value *ifCond = nullptr;
6745 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6746 SmallVector<Value> mapVars;
6747 SmallVector<Value> useDevicePtrVars;
6748 SmallVector<Value> useDeviceAddrVars;
 // Left unset for omp.target_data; it is only read (via &RTLFn) on the
 // non-TargetDataOp path at the end of this function.
6749 llvm::omp::RuntimeFunction RTLFn;
6750 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
6751 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6752
6753 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6754 llvm::OpenMPIRBuilder::TargetDataInfo info(
6755 /*RequiresDevicePointerInfo=*/true,
6756 /*SeparateBeginEndCalls=*/true);
6757 assert(!ompBuilder->Config.isTargetDevice() &&
6758 "target data/enter/exit/update are host ops");
6759 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6760
 // Looks up the translated device operand and converts it to a signed i64.
6761 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
6762 llvm::Value *v = moduleTranslation.lookupValue(dev);
6763 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
6764 };
6765
 // NOTE(review): the head of the dispatch expression (listing line 6767) is
 // absent from this listing; the .Case/.DefaultUnreachable chain below
 // suggests an llvm::TypeSwitch over `op` - confirm against upstream.
6766 LogicalResult result =
6768 .Case([&](omp::TargetDataOp dataOp) {
6769 if (failed(checkImplementationStatus(*dataOp)))
6770 return failure();
6771
6772 if (auto ifVar = dataOp.getIfExpr())
6773 ifCond = moduleTranslation.lookupValue(ifVar);
6774
6775 if (mlir::Value devId = dataOp.getDevice())
6776 deviceID = getDeviceID(devId);
6777
6778 mapVars = dataOp.getMapVars();
6779 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6780 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6781 return success();
6782 })
6783 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6784 if (failed(checkImplementationStatus(*enterDataOp)))
6785 return failure();
6786
6787 if (auto ifVar = enterDataOp.getIfExpr())
6788 ifCond = moduleTranslation.lookupValue(ifVar);
6789
6790 if (mlir::Value devId = enterDataOp.getDevice())
6791 deviceID = getDeviceID(devId);
6792
6793 RTLFn =
6794 enterDataOp.getNowait()
6795 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6796 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6797 mapVars = enterDataOp.getMapVars();
6798 info.HasNoWait = enterDataOp.getNowait();
6799 return success();
6800 })
6801 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6802 if (failed(checkImplementationStatus(*exitDataOp)))
6803 return failure();
6804
6805 if (auto ifVar = exitDataOp.getIfExpr())
6806 ifCond = moduleTranslation.lookupValue(ifVar);
6807
6808 if (mlir::Value devId = exitDataOp.getDevice())
6809 deviceID = getDeviceID(devId);
6810
6811 RTLFn = exitDataOp.getNowait()
6812 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6813 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6814 mapVars = exitDataOp.getMapVars();
6815 info.HasNoWait = exitDataOp.getNowait();
6816 return success();
6817 })
6818 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6819 if (failed(checkImplementationStatus(*updateDataOp)))
6820 return failure();
6821
6822 if (auto ifVar = updateDataOp.getIfExpr())
6823 ifCond = moduleTranslation.lookupValue(ifVar);
6824
6825 if (mlir::Value devId = updateDataOp.getDevice())
6826 deviceID = getDeviceID(devId);
6827
6828 RTLFn =
6829 updateDataOp.getNowait()
6830 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6831 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6832 mapVars = updateDataOp.getMapVars();
6833 info.HasNoWait = updateDataOp.getNowait();
6834 return success();
6835 })
6836 .DefaultUnreachable("unexpected operation");
6837
6838 if (failed(result))
6839 return failure();
6840 // Pretend we have IF(false) if we're not doing offload.
6841 if (!isOffloadEntry)
6842 ifCond = builder.getFalse();
6843
6844 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6845 MapInfoData mapData;
6846 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
6847 builder, useDevicePtrVars, useDeviceAddrVars);
6848
6849 // Fill up the arrays with all the mapped variables.
6850 MapInfosTy combinedInfo;
6851 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6852 builder.restoreIP(codeGenIP);
6853 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6854 targetDirective);
6855 return combinedInfo;
6856 };
6857
6858 // Define a lambda to apply mappings between use_device_addr and
6859 // use_device_ptr base pointers, and their associated block arguments.
 // For each (block argument, use_device variable) pair it scans the map
 // entries for the first one that has the same base pointer value and the
 // requested device-info kind, then binds the block argument to that entry's
 // base pointer (passed through `mapper` when one is given). A null result
 // from `mapper` makes the scan continue with the next entry.
 // NOTE(review): `blockArgs`, used below, is not declared in any visible
 // line; its parameter declaration (listing line 6863) is absent from this
 // listing.
6860 auto mapUseDevice =
6861 [&moduleTranslation](
6862 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6864 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
6865 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
6866 for (auto [arg, useDevVar] :
6867 llvm::zip_equal(blockArgs, useDeviceVars)) {
6868
 // Prefers var_ptr_ptr over var_ptr when the former is present.
6869 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6870 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6871 : mapInfoOp.getVarPtr();
6872 };
6873
6874 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6875 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6876 mapInfoData.MapClause, mapInfoData.DevicePointers,
6877 mapInfoData.BasePointers)) {
6878 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6879 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6880 devicePointer != type)
6881 continue;
6882
6883 if (llvm::Value *devPtrInfoMap =
6884 mapper ? mapper(basePointer) : basePointer) {
6885 moduleTranslation.mapValue(arg, devPtrInfoMap);
6886 break;
6887 }
6888 }
6889 }
6890 };
6891
 // Body callback, used only for omp.target_data (asserted below). The region
 // is inlined exactly once: in the Priv phase when info.DevicePtrInfoMap has
 // entries, otherwise in the NoPriv phase.
6892 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6893 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6894 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6895 // We must always restoreIP regardless of doing anything the caller
6896 // does not restore it, leading to incorrect (no) branch generation.
6897 builder.restoreIP(codeGenIP);
6898 assert(isa<omp::TargetDataOp>(op) &&
6899 "BodyGen requested for non TargetDataOp");
6900 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6901 Region &region = cast<omp::TargetDataOp>(op).getRegion();
6902 switch (bodyGenType) {
6903 case BodyGenTy::Priv:
6904 // Check if any device ptr/addr info is available
6905 if (!info.DevicePtrInfoMap.empty()) {
 // use_device_addr arguments are bound to a load through the recorded
 // `.second` value; entries whose `.second` is null are skipped.
6906 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6907 blockArgIface.getUseDeviceAddrBlockArgs(),
6908 useDeviceAddrVars, mapData,
6909 [&](llvm::Value *basePointer) -> llvm::Value * {
6910 if (!info.DevicePtrInfoMap[basePointer].second)
6911 return nullptr;
6912 return builder.CreateLoad(
6913 builder.getPtrTy(),
6914 info.DevicePtrInfoMap[basePointer].second);
6915 });
 // use_device_ptr arguments are bound to the `.second` value directly.
6916 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6917 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6918 mapData, [&](llvm::Value *basePointer) {
6919 return info.DevicePtrInfoMap[basePointer].second;
6920 });
6921
6922 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6923 moduleTranslation)))
6924 return llvm::make_error<PreviouslyReportedError>();
6925 }
6926 break;
6927 case BodyGenTy::DupNoPriv:
6928 if (info.DevicePtrInfoMap.empty()) {
6929 // For host device we still need to do the mapping for codegen,
6930 // otherwise it may try to lookup a missing value.
6931 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6932 blockArgIface.getUseDeviceAddrBlockArgs(),
6933 useDeviceAddrVars, mapData);
6934 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6935 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6936 mapData);
6937 }
6938 break;
6939 case BodyGenTy::NoPriv:
6940 // If device info is available then region has already been generated
6941 if (info.DevicePtrInfoMap.empty()) {
6942 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6943 moduleTranslation)))
6944 return llvm::make_error<PreviouslyReportedError>();
6945 }
6946 break;
6947 }
6948 return builder.saveIP();
6949 };
6950
 // Resolves the user-defined mapper of map entry `i`, if it has one, and
 // records in `info` that at least one mapper is in use.
6951 auto customMapperCB =
6952 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6953 if (!combinedInfo.Mappers[i])
6954 return nullptr;
6955 info.HasMapper = true;
6956 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6957 moduleTranslation, targetDirective);
6958 };
6959
6960 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 // NOTE(review): `deallocBlocks`, used below, is not declared in any visible
 // line; its declaration (listing line 6961) is absent from this listing.
6962 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6963 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
 // omp.target_data supplies the body callback and no runtime function; the
 // enter/exit/update ops supply the runtime function chosen above instead.
6964 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6965 if (isa<omp::TargetDataOp>(op))
6966 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6967 deallocBlocks, deviceID, ifCond, info,
6968 genMapInfoCB, customMapperCB,
6969 /*MapperFunc=*/nullptr, bodyGenCB,
6970 /*DeviceAddrCB=*/nullptr);
6971 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6972 deallocBlocks, deviceID, ifCond, info,
6973 genMapInfoCB, customMapperCB, &RTLFn);
6974 }();
6975
6976 if (failed(handleError(afterIP, *op)))
6977 return failure();
6978
6979 builder.restoreIP(*afterIP);
6980 return success();
6981}
6982
/// Translates an `omp.distribute` operation to LLVM IR by handing a
/// body-generation callback to `OpenMPIRBuilder::createDistribute`.
///
/// If the enclosing `omp.teams` operation has reductions that
/// `getDistributeCapturingTeamsReduction` associates with this distribute op,
/// the teams reduction state is set up before the distribute region is emitted
/// and processed again after it.
///
/// Returns failure if the implementation-status check, the reduction setup or
/// the `createDistribute` call reports an error.
///
/// NOTE(review): several statements of this function are not visible in this
/// listing (e.g. the declarations of `reductionDecls`, `afterAllocas`,
/// `regionBlock` and `deallocBlocks`, and the callees of the reduction
/// setup/teardown calls) - confirm against the upstream source.
6983static LogicalResult
6984convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
6985 LLVM::ModuleTranslation &moduleTranslation) {
6986 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6987 auto distributeOp = cast<omp::DistributeOp>(opInst);
6988 if (failed(checkImplementationStatus(opInst)))
6989 return failure();
6990
6991 /// Process teams op reduction in distribute if the reduction is contained in
6992 /// this specific distribute op.
6993 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
6994 bool doDistributeReduction =
6995 teamsOp && getDistributeCapturingTeamsReduction(teamsOp) == distributeOp;
6996
6997 DenseMap<Value, llvm::Value *> reductionVariableMap;
6998 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
// NOTE(review): `reductionDecls` is used below but its declaration is not
// visible in this listing.
7000 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
7001 llvm::ArrayRef<bool> isByRef;
7002
7003 if (doDistributeReduction) {
7004 isByRef = getIsByRef(teamsOp.getReductionByref());
7005 assert(isByRef.size() == teamsOp.getNumReductionVars());
7006
7007 collectReductionDecls(teamsOp, reductionDecls);
7008 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7009 findAllocInsertPoints(builder, moduleTranslation);
7010
7011 MutableArrayRef<BlockArgument> reductionArgs =
7012 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
7013 .getReductionBlockArgs();
7014
// NOTE(review): the callee of the following argument list is missing from
// this listing; presumably it allocates and initializes the reduction
// variables - confirm.
7016 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
7017 reductionDecls, privateReductionVariables, reductionVariableMap,
7018 isByRef)))
7019 return failure();
7020 }
7021
7022 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
// Callback invoked by `createDistribute` to emit the body of the distribute
// region at `codeGenIP`.
7023 auto bodyGenCB =
7024 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7025 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
7026 // Save the alloca insertion point on ModuleTranslation stack for use in
7027 // nested regions.
7029 moduleTranslation, allocaIP, deallocBlocks);
7030
7031 // DistributeOp has only one region associated with it.
7032 builder.restoreIP(codeGenIP);
7033 PrivateVarsInfo privVarsInfo(distributeOp);
7034
7036 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
7037 if (handleError(afterAllocas, opInst).failed())
7038 return llvm::make_error<PreviouslyReportedError>();
7039
7040 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
7041 opInst)
7042 .failed())
7043 return llvm::make_error<PreviouslyReportedError>();
7044
7045 if (failed(copyFirstPrivateVars(
7046 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
7047 privVarsInfo.llvmVars, privVarsInfo.privatizers,
7048 distributeOp.getPrivateNeedsBarrier())))
7049 return llvm::make_error<PreviouslyReportedError>();
7050
7051 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7052 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7054 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
7055 builder, moduleTranslation);
7056 if (!regionBlock)
7057 return regionBlock.takeError();
7058 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
7059
7060 // Skip applying a workshare loop below when translating 'distribute
7061 // parallel do' (it's been already handled by this point while translating
7062 // the nested omp.wsloop).
7063 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
7064 // TODO: Add support for clauses which are valid for DISTRIBUTE
7065 // constructs. Static schedule is the default.
7066 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
7067 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
7068 : omp::ClauseScheduleKind::Static;
7069 // dist_schedule clauses are ordered - otherwise this should be false
7070 bool isOrdered = hasDistSchedule;
// No schedule modifier is ever set here, so both modifier comparisons passed
// to applyWorkshareLoop below evaluate to false.
7071 std::optional<omp::ScheduleModifier> scheduleMod;
7072 bool isSimd = false;
7073 llvm::omp::WorksharingLoopType workshareLoopType =
7074 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
7075 bool loopNeedsBarrier = false;
7076 llvm::Value *chunk = moduleTranslation.lookupValue(
7077 distributeOp.getDistScheduleChunkSize());
7078 llvm::CanonicalLoopInfo *loopInfo =
7079 findCurrentLoopInfo(moduleTranslation);
7080 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
7081 ompBuilder->applyWorkshareLoop(
7082 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
7083 convertToScheduleKind(schedule), chunk, isSimd,
7084 scheduleMod == omp::ScheduleModifier::monotonic,
7085 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
7086 workshareLoopType, false, hasDistSchedule, chunk);
7087
7088 if (!wsloopIP)
7089 return wsloopIP.takeError();
7090 }
7091 if (failed(cleanupPrivateVars(distributeOp, builder, moduleTranslation,
7092 distributeOp.getLoc(), privVarsInfo)))
7093 return llvm::make_error<PreviouslyReportedError>();
7094
7095 return llvm::Error::success();
7096 };
7097
// NOTE(review): the declaration of `deallocBlocks` is not visible in this
// listing.
7099 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7100 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
7101 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7102 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7103 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
7104
7105 if (failed(handleError(afterIP, opInst)))
7106 return failure();
7107
// Continue emitting code after the generated distribute construct.
7108 builder.restoreIP(*afterIP);
7109
7110 if (doDistributeReduction) {
7111 // Process the reductions if required.
// NOTE(review): the callee (and presumably a `return`) of the following
// argument list is missing from this listing - confirm.
7113 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
7114 privateReductionVariables, isByRef,
7115 /*isNoWait*/ false, /*isTeamsReduction*/ true);
7116 }
7117 return success();
7118}
7119
7120/// Lowers the FlagsAttr which is applied to the module on the device
7121/// pass when offloading, this attribute contains OpenMP RTL globals that can
7122/// be passed as flags to the frontend, otherwise they are set to default
7123static LogicalResult
7124convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
7125 LLVM::ModuleTranslation &moduleTranslation) {
7126 if (!cast<mlir::ModuleOp>(op))
7127 return failure();
7128
7129 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7130
7131 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
7132 attribute.getOpenmpDeviceVersion());
7133
7134 if (attribute.getNoGpuLib())
7135 return success();
7136
7137 ompBuilder->createGlobalFlag(
7138 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
7139 "__omp_rtl_debug_kind");
7140 ompBuilder->createGlobalFlag(
7141 attribute
7142 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
7143 ,
7144 "__omp_rtl_assume_teams_oversubscription");
7145 ompBuilder->createGlobalFlag(
7146 attribute
7147 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
7148 ,
7149 "__omp_rtl_assume_threads_oversubscription");
7150 ompBuilder->createGlobalFlag(
7151 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
7152 "__omp_rtl_assume_no_thread_state");
7153 ompBuilder->createGlobalFlag(
7154 attribute
7155 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
7156 ,
7157 "__omp_rtl_assume_no_nested_parallelism");
7158 return success();
7159}
7160
7161static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
7162 omp::TargetOp targetOp,
7163 llvm::OpenMPIRBuilder &ompBuilder,
7164 llvm::vfs::FileSystem &vfs,
7165 llvm::StringRef parentName = "") {
7166 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
7167 assert(fileLoc && "No file found from location");
7168
7169 auto fileInfoCallBack = [&fileLoc]() {
7170 return std::pair<std::string, uint64_t>(
7171 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
7172 };
7173
7174 targetInfo =
7175 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
7176}
7177
7178static void
7179handleDeclareTargetMapVar(MapInfoData &mapData,
7180 LLVM::ModuleTranslation &moduleTranslation,
7181 llvm::IRBuilderBase &builder, llvm::Function *func) {
7182 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
7183 "function only supported for target device codegen");
7184 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7185 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
7186 // In the case of declare target mapped variables, the basePointer is
7187 // the reference pointer generated by the convertDeclareTargetAttr
7188 // method. Whereas the kernelValue is the original variable, so for
7189 // the device we must replace all uses of this original global variable
7190 // (stored in kernelValue) with the reference pointer (stored in
7191 // basePointer for declare target mapped variables), as for device the
7192 // data is mapped into this reference pointer and should be loaded
7193 // from it, the original variable is discarded. On host both exist and
7194 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
7195 // function to link the two variables in the runtime and then both the
7196 // reference pointer and the pointer are assigned in the kernel argument
7197 // structure for the host.
7198 if (!mapData.IsDeclareTarget[i])
7199 continue;
7200 // If the original map value is a constant, then we have to make sure all
7201 // of it's uses within the current kernel/function that we are going to
7202 // rewrite are converted to instructions, as we will be altering the old
7203 // use (OriginalValue) from a constant to an instruction, which will be
7204 // illegal and ICE the compiler if the user is a constant expression of
7205 // some kind e.g. a constant GEP.
7206 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
7207 convertUsersOfConstantsToInstructions(constant, func, false);
7208
7209 // The users iterator will get invalidated if we modify an element,
7210 // so we populate this vector of uses to alter each user on an
7211 // individual basis to emit its own load (rather than one load for
7212 // all).
7214 for (llvm::User *user : mapData.OriginalValue[i]->users())
7215 userVec.push_back(user);
7216
7217 for (llvm::User *user : userVec) {
7218 auto *insn = dyn_cast<llvm::Instruction>(user);
7219 if (!insn || insn->getFunction() != func)
7220 continue;
7221 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7222 llvm::Value *substitute = mapData.BasePointers[i];
7223 auto declTarPtr =
7224 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
7225 if (isDeclareTargetLink(declTarPtr) ||
7226 (isDeclareTargetTo(declTarPtr) &&
7227 moduleTranslation.getOpenMPBuilder()
7228 ->Config.hasRequiresUnifiedSharedMemory())) {
7229 builder.SetCurrentDebugLocation(insn->getDebugLoc());
7230 substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(),
7231 mapData.BasePointers[i]);
7232 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
7233 }
7234 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
7235 }
7236 }
7237}
7238
7239// The createDeviceArgumentAccessor function generates
7240// instructions for retrieving (accessing) kernel
7241// arguments inside of the device kernel for use by
7242// the kernel. This enables different semantics such as
7243// the creation of temporary copies of data allowing
7244// semantics like read-only/no host write back kernel
7245// arguments.
7246//
7247// This currently implements a very light version of Clang's
7248// EmitParmDecl's handling of direct argument handling as well
7249// as a portion of the argument access generation based on
7250// capture types found at the end of emitOutlinedFunctionPrologue
7251// in Clang. The indirect path handling of EmitParmDecl's may be
7252// required for future work, but a direct 1-to-1 copy doesn't seem
7253// possible as the logic is rather scattered throughout Clang's
7254// lowering and perhaps we wish to deviate slightly.
7255//
7256// \param mapData - A container containing vectors of information
7257// corresponding to the input argument, which should have a
7258// corresponding entry in the MapInfoData containers
7259// OriginalValue's.
7260// \param arg - This is the generated kernel function argument that
7261// corresponds to the passed in input argument. We generated different
7262// accesses of this Argument, based on capture type and other Input
7263// related information.
7264// \param input - This is the host side value that will be passed to
7265// the kernel i.e. the kernel input, we rewrite all uses of this within
7266// the kernel (as we generate the kernel body based on the target's region
7267// which maintains references to the original input) to the retVal argument
7268// upon exit of this function inside of the OMPIRBuilder. This interlinks
7269// the kernel argument to future uses of it in the function providing
7270// appropriate "glue" instructions in between.
7271// \param retVal - This is the value that all uses of input inside of the
7272// kernel will be re-written to, the goal of this function is to generate
7273// an appropriate location for the kernel argument to be accessed from,
7274// e.g. ByRef will result in a temporary allocation location and then
7275// a store of the kernel argument into this allocated memory which
7276// will then be loaded from, ByCopy will use the allocated memory
7277// directly.
// \param allocaIP - Insertion point restored before the storage for the
// argument is created (used for the non-shared allocation path).
// \param codeGenIP - Insertion point restored before the access to the
// argument is emitted; shared allocations are placed at the first insertion
// point of its block.
// \return The builder's insertion point after the access has been emitted.
//
// NOTE(review): the trailing `deallocIPs` parameter, the declaration of
// `blockArgsPairs` and the second operand of the shared-memory condition are
// not visible in this listing - confirm against the upstream source.
7278static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
7279 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
7280 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
7281 llvm::OpenMPIRBuilder &ompBuilder,
7282 LLVM::ModuleTranslation &moduleTranslation,
7283 llvm::IRBuilderBase::InsertPoint allocaIP,
7284 llvm::IRBuilderBase::InsertPoint codeGenIP,
7286 assert(ompBuilder.Config.isTargetDevice() &&
7287 "function only supported for target device codegen");
7288 builder.restoreIP(allocaIP);
7289
// ByRef is the capture kind used when no matching map entry is found below.
7290 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
7291 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
7292 ompBuilder.M.getContext());
7293 unsigned alignmentValue = 0;
7294 BlockArgument mlirArg;
7296 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
7297 blockArgsPairs);
7298 // Find the associated MapInfoData entry for the current input
7299 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
7300 if (mapData.OriginalValue[i] == input) {
7301 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7302 capture = mapOp.getMapCaptureType();
7303 // Get information of alignment of mapped object
7304 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
7305 mapOp.getVarPtrType(), ompBuilder.M.getDataLayout());
7306
7307 // Find the corresponding entry block argument, which can be associated to
7308 // a map, use_device* or has_device* clause.
7309 for (auto &[val, arg] : blockArgsPairs) {
7310 if (mapOp.getResult() == val) {
7311 mlirArg = arg;
7312 break;
7313 }
7314 }
7315 assert(mlirArg && "expected to find entry block argument for map clause");
7316 break;
7317 }
7318 }
7319
7320 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
7321 unsigned int defaultAS =
7322 ompBuilder.M.getDataLayout().getProgramAddressSpace();
7323
7324 // Create the allocation for the argument.
7325 llvm::Value *v = nullptr;
7326 if (omp::opInSharedDeviceContext(*targetOp) &&
7328 // Use the beginning of the codeGenIP rather than the usual allocation point
7329 // for shared memory allocations because otherwise these would be done prior
7330 // to the target initialization call. Also, the exit block (where the
7331 // deallocation is placed) is only executed if the initialization call
7332 // succeeds.
7333 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
7334 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
7335
7336 // Create deallocations in all provided deallocation points and then restore
7337 // the insertion point to right after the new allocations.
7338 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7339 for (auto deallocIP : deallocIPs) {
7340 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
7341 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
7342 }
7343 } else {
7344 // Use the current point, which was previously set to allocaIP.
7345 v = builder.CreateAlloca(arg.getType(), allocaAS);
7346
// Cast the alloca to the default address space when the two differ and the
// argument is a pointer.
7347 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
7348 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
7349 }
7350
// Spill the incoming kernel argument into the storage created above.
7351 builder.CreateStore(&arg, v);
7352
7353 builder.restoreIP(codeGenIP);
7354
7355 switch (capture) {
7356 case omp::VariableCaptureKind::ByCopy: {
// ByCopy: the storage holding the argument is used directly.
7357 retVal = v;
7358 break;
7359 }
7360 case omp::VariableCaptureKind::ByRef: {
// ByRef: the stored pointer is loaded back and the load becomes the
// replacement value.
7361 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
7362 v->getType(), v,
7363 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
7364 // CreateAlignedLoad function creates similar LLVM IR:
7365 // %res = load ptr, ptr %input, align 8
7366 // This LLVM IR does not contain information about alignment
7367 // of the loaded value. We need to add !align metadata to unblock
7368 // optimizer. The existence of the !align metadata on the instruction
7369 // tells the optimizer that the value loaded is known to be aligned to
7370 // a boundary specified by the integer value in the metadata node.
7371 // Example:
7372 // %res = load ptr, ptr %input, align 8, !align !align_md_node
7373 // ^ ^
7374 // | |
7375 // alignment of %input address |
7376 // |
7377 // alignment of %res object
7378 if (v->getType()->isPointerTy() && alignmentValue) {
7379 llvm::MDBuilder MDB(builder.getContext());
7380 loadInst->setMetadata(
7381 llvm::LLVMContext::MD_align,
7382 llvm::MDNode::get(builder.getContext(),
7383 MDB.createConstant(llvm::ConstantInt::get(
7384 llvm::Type::getInt64Ty(builder.getContext()),
7385 alignmentValue))));
7386 }
7387 retVal = loadInst;
7388
7389 break;
7390 }
7391 case omp::VariableCaptureKind::This:
7392 case omp::VariableCaptureKind::VLAType:
7393 // TODO: Consider returning error to use standard reporting for
7394 // unimplemented features.
7395 assert(false && "Currently unsupported capture kind");
7396 break;
7397 }
7398
7399 return builder.saveIP();
7400}
7401
7402/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
7403/// operation and populate output variables with their corresponding host value
7404/// (i.e. operand evaluated outside of the target region), based on their uses
7405/// inside of the target region.
7406///
7407/// Loop bounds and steps are only optionally populated, if output vectors are
7408/// provided.
7409static void
7410extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
7411 Value &numTeamsLower, Value &numTeamsUpper,
7412 Value &threadLimit,
7413 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
7414 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
7415 llvm::SmallVectorImpl<Value> *steps = nullptr) {
7416 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
7417 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
7418 blockArgIface.getHostEvalBlockArgs())) {
7419 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
7420
7421 for (Operation *user : blockArg.getUsers()) {
7423 .Case([&](omp::TeamsOp teamsOp) {
7424 if (teamsOp.getNumTeamsLower() == blockArg)
7425 numTeamsLower = hostEvalVar;
7426 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
7427 blockArg))
7428 numTeamsUpper = hostEvalVar;
7429 else if (!teamsOp.getThreadLimitVars().empty() &&
7430 teamsOp.getThreadLimit(0) == blockArg)
7431 threadLimit = hostEvalVar;
7432 else
7433 llvm_unreachable("unsupported host_eval use");
7434 })
7435 .Case([&](omp::ParallelOp parallelOp) {
7436 if (!parallelOp.getNumThreadsVars().empty() &&
7437 parallelOp.getNumThreads(0) == blockArg)
7438 numThreads = hostEvalVar;
7439 else
7440 llvm_unreachable("unsupported host_eval use");
7441 })
7442 .Case([&](omp::LoopNestOp loopOp) {
7443 auto processBounds =
7444 [&](OperandRange opBounds,
7445 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
7446 bool found = false;
7447 for (auto [i, lb] : llvm::enumerate(opBounds)) {
7448 if (lb == blockArg) {
7449 found = true;
7450 if (outBounds)
7451 (*outBounds)[i] = hostEvalVar;
7452 }
7453 }
7454 return found;
7455 };
7456 bool found =
7457 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
7458 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
7459 found;
7460 found = processBounds(loopOp.getLoopSteps(), steps) || found;
7461 (void)found;
7462 assert(found && "unsupported host_eval use");
7463 })
7464 .DefaultUnreachable("unsupported host_eval use");
7465 }
7466 }
7467}
7468
7469/// If \p op is of the given type parameter, return it casted to that type.
7470/// Otherwise, if its immediate parent operation (or some other higher-level
7471/// parent, if \p immediateParent is false) is of that type, return that parent
7472/// casted to the given type.
7473///
7474/// If \p op is \c null or neither it or its parent(s) are of the specified
7475/// type, return a \c null operation.
7476template <typename OpTy>
7477static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
7478 if (!op)
7479 return OpTy();
7480
7481 if (OpTy casted = dyn_cast<OpTy>(op))
7482 return casted;
7483
7484 if (immediateParent)
7485 return dyn_cast_if_present<OpTy>(op->getParentOp());
7486
7487 return op->getParentOfType<OpTy>();
7488}
7489
7490/// If the given \p value is defined by an \c llvm.mlir.constant operation and
7491/// it is of an integer type, return its value.
7492static std::optional<int64_t> extractConstInteger(Value value) {
7493 if (!value)
7494 return std::nullopt;
7495
7496 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
7497 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7498 return constAttr.getInt();
7499
7500 return std::nullopt;
7501}
7502
7503static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
7504 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
7505 uint64_t sizeInBytes = sizeInBits / 8;
7506 return sizeInBytes;
7507}
7508
7509template <typename OpTy>
7510static uint64_t getReductionDataSize(OpTy &op) {
7511 if (op.getNumReductionVars() > 0) {
7513 collectReductionDecls(op, reductions);
7514
7516 members.reserve(reductions.size());
7517 for (omp::DeclareReductionOp &red : reductions) {
7518 // For by-ref reductions, use the actual element type rather than the
7519 // pointer type so that the buffer size matches the access pattern in
7520 // the copy/reduce callbacks generated by OMPIRBuilder.
7521 if (red.getByrefElementType())
7522 members.push_back(*red.getByrefElementType());
7523 else
7524 members.push_back(red.getType());
7525 }
7526 Operation *opp = op.getOperation();
7527 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7528 opp->getContext(), members, /*isPacked=*/false);
7529 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
7530 return getTypeByteSize(structType, dl);
7531 }
7532 return 0;
7533}
7534
7535/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
7536/// values as stated by the corresponding clauses, if constant.
7537///
7538/// These default values must be set before the creation of the outlined LLVM
7539/// function for the target region, so that they can be used to initialize the
7540/// corresponding global `ConfigurationEnvironmentTy` structure.
7541static void
7542initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
7543 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7544 bool isTargetDevice, bool isGPU) {
7545 // TODO: Handle constant 'if' clauses.
7546
7547 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
7548 if (!isTargetDevice) {
7549 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
7550 threadLimit);
7551 } else {
7552 // In the target device, values for these clauses are not passed as
7553 // host_eval, but instead evaluated prior to entry to the region. This
7554 // ensures values are mapped and available inside of the target region.
7555 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7556 numTeamsLower = teamsOp.getNumTeamsLower();
7557 // Handle num_teams upper bounds (only first value for now)
7558 if (!teamsOp.getNumTeamsUpperVars().empty())
7559 numTeamsUpper = teamsOp.getNumTeams(0);
7560 if (!teamsOp.getThreadLimitVars().empty())
7561 threadLimit = teamsOp.getThreadLimit(0);
7562 }
7563
7564 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
7565 if (!parallelOp.getNumThreadsVars().empty())
7566 numThreads = parallelOp.getNumThreads(0);
7567 }
7568 }
7569
7570 // Handle clauses impacting the number of teams.
7571
7572 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7573 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7574 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
7575 // match clang and set min and max to the same value.
7576 if (numTeamsUpper) {
7577 if (auto val = extractConstInteger(numTeamsUpper))
7578 minTeamsVal = maxTeamsVal = *val;
7579 } else {
7580 minTeamsVal = maxTeamsVal = 0;
7581 }
7582 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
7583 /*immediateParent=*/true) ||
7585 /*immediateParent=*/true)) {
7586 minTeamsVal = maxTeamsVal = 1;
7587 } else {
7588 minTeamsVal = maxTeamsVal = -1;
7589 }
7590
7591 // Handle clauses impacting the number of threads.
7592
7593 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
7594 if (!clauseValue)
7595 return;
7596
7597 if (auto val = extractConstInteger(clauseValue))
7598 result = *val;
7599
7600 // Found an applicable clause, so it's not undefined. Mark as unknown
7601 // because it's not constant.
7602 if (result < 0)
7603 result = 0;
7604 };
7605
7606 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
7607 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7608 if (!targetOp.getThreadLimitVars().empty())
7609 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7610 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7611
7612 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
7613 int32_t maxThreadsVal = -1;
7615 setMaxValueFromClause(numThreads, maxThreadsVal);
7616 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
7617 /*immediateParent=*/true))
7618 maxThreadsVal = 1;
7619
7620 // For max values, < 0 means unset, == 0 means set but unknown. Select the
7621 // minimum value between 'max_threads' and 'thread_limit' clauses that were
7622 // set.
7623 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7624 if (combinedMaxThreadsVal < 0 ||
7625 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7626 combinedMaxThreadsVal = teamsThreadLimitVal;
7627
7628 if (combinedMaxThreadsVal < 0 ||
7629 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7630 combinedMaxThreadsVal = maxThreadsVal;
7631
7632 int32_t reductionDataSize = 0;
7633 if (isGPU && capturedOp) {
7634 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
7635 reductionDataSize = getReductionDataSize(teamsOp);
7636 }
7637
7638 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
7639 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7640 switch (execMode) {
7641 case omp::TargetExecMode::bare:
7642 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7643 break;
7644 case omp::TargetExecMode::generic:
7645 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7646 break;
7647 case omp::TargetExecMode::spmd:
7648 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7649 break;
7650 case omp::TargetExecMode::no_loop:
7651 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
7652 break;
7653 }
7654 attrs.MinTeams = minTeamsVal;
7655 attrs.MaxTeams.front() = maxTeamsVal;
7656 attrs.MinThreads = 1;
7657 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7658 attrs.ReductionDataSize = reductionDataSize;
7659 // TODO: Allow modified buffer length similar to
7660 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
7661 if (attrs.ReductionDataSize != 0)
7662 attrs.ReductionBufferLength = 1024;
7663}
7664
7665/// Gather LLVM runtime values for all clauses evaluated in the host that are
7666/// passed to the kernel invocation.
7667///
7668/// This function must be called only when compiling for the host. Also, it will
7669/// only provide correct results if it's called after the body of \c targetOp
7670/// has been fully generated.
7671static void
7672initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
7673 LLVM::ModuleTranslation &moduleTranslation,
7674 omp::TargetOp targetOp, Operation *capturedOp,
7675 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
7676 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
7677 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7678
7679 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7680 llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
7681 steps(numLoops);
7682 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
7683 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
7684
7685 // TODO: Handle constant 'if' clauses.
7686 if (!targetOp.getThreadLimitVars().empty()) {
7687 Value targetThreadLimit = targetOp.getThreadLimit(0);
7688 attrs.TargetThreadLimit.front() =
7689 moduleTranslation.lookupValue(targetThreadLimit);
7690 }
7691
7692 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
7693 // truncate or sign extend lower and upper num_teams bounds as well as
7694 // thread_limit to match int32 ABI requirements for the OpenMP runtime.
7695 if (numTeamsLower)
7696 attrs.MinTeams = builder.CreateSExtOrTrunc(
7697 moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
7698
7699 if (numTeamsUpper)
7700 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7701 moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
7702
7703 if (teamsThreadLimit)
7704 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7705 moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
7706
7707 if (numThreads)
7708 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
7709
7710 bool hostEvalTripCount;
7711 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7712 if (hostEvalTripCount) {
7713 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7714 attrs.LoopTripCount = nullptr;
7715
7716 // To calculate the trip count, we multiply together the trip counts of
7717 // every collapsed canonical loop. We don't need to create the loop nests
7718 // here, since we're only interested in the trip count.
7719 for (auto [loopLower, loopUpper, loopStep] :
7720 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7721 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
7722 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
7723 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
7724
7725 if (!lowerBound || !upperBound || !step) {
7726 attrs.LoopTripCount = nullptr;
7727 break;
7728 }
7729
7730 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7731 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7732 loc, lowerBound, upperBound, step, /*IsSigned=*/true,
7733 loopOp.getLoopInclusive());
7734
7735 if (!attrs.LoopTripCount) {
7736 attrs.LoopTripCount = tripCount;
7737 continue;
7738 }
7739
7740 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
7741 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
7742 {}, /*HasNUW=*/true);
7743 }
7744 }
7745
7746 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7747 if (mlir::Value devId = targetOp.getDevice()) {
7748 attrs.DeviceID = moduleTranslation.lookupValue(devId);
7749 attrs.DeviceID =
7750 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
7751 }
7752}
7753
7754static llvm::omp::OMPDynGroupprivateFallbackType
7755getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr) {
7756 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7757 : omp::FallbackModifier::default_mem;
7758 switch (fb) {
7759 case omp::FallbackModifier::abort:
7760 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7761 case omp::FallbackModifier::null:
7762 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7763 case omp::FallbackModifier::default_mem:
7764 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
7765 }
7766
7767 llvm_unreachable("unexpected dyn_groupprivate fallback type");
7768}
7769
/// Lowers an `omp.target` operation to LLVM IR through
/// OpenMPIRBuilder::createTarget.
///
/// The visible steps are: validate the op with checkImplementationStatus, point
/// the builder's debug location at the parent function, record which private
/// variables have an associated map block argument, build the body / map-info /
/// argument-accessor / custom-mapper callbacks, gather kernel inputs from
/// host-evaluated values and map operands, build dependency data, and finally
/// call createTarget. Returns failure() if any of these steps reports an error.
///
/// NOTE(review): a few declarations used below (`afterAllocas`, `exitBlock`,
/// `kernelInput`, `deallocBlocks`, the `deallocIPs` lambda parameter) are not
/// visible in this listing -- confirm against the upstream file.
7770static LogicalResult
7771convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
7772 LLVM::ModuleTranslation &moduleTranslation) {
7773 auto targetOp = cast<omp::TargetOp>(opInst);
7774
7775 // The current debug location already has the DISubprogram for the outlined
7776 // function that will be created for the target op. We save it here so that
7777 // we can set it on the outlined function.
7778 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7779 if (failed(checkImplementationStatus(opInst)))
7780 return failure();
7781
7782 // During the handling of target op, we will generate instructions in the
7783 // parent function like call to the outlined function or branch to a new
7784 // BasicBlock. We set the debug location here to parent function so that those
7785 // get the correct debug locations. For outlined functions, the normal MLIR op
7786 // conversion will automatically pick the correct location.
7787 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7788 assert(parentBB && "No insert block is set for the builder");
7789 llvm::Function *parentLLVMFn = parentBB->getParent();
7790 assert(parentLLVMFn && "Parent Function must be valid");
7791 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7792 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7793 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7794 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7795
7796 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7797 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7798 bool isGPU = ompBuilder->Config.isGPU();
7799
7800 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
7801 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7802 auto &targetRegion = targetOp.getRegion();
7803 // Holds the private vars that have been mapped along with the block
7804 // argument that corresponds to the MapInfoOp corresponding to the private
7805 // var in question. So, for instance:
7806 //
7807 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
7808 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
7809 //
7810 // Then, %10 has been created so that the descriptor can be used by the
7811 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
7812 // %arg0} in the mappedPrivateVars map.
7813 llvm::DenseMap<Value, Value> mappedPrivateVars;
7814 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
7815 SmallVector<Value> mapVars = targetOp.getMapVars();
7816 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
7817 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
7818 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
 // Set by bodyCB once the outlined function exists; read again at the end of
 // this function when handling declare target map variables.
7819 llvm::Function *llvmOutlinedFn = nullptr;
7820 TargetDirectiveEnumTy targetDirective =
7821 getTargetDirectiveEnumTyFromOp(&opInst);
7822
7823 // TODO: It can also be false if a compile-time constant `false` IF clause is
7824 // specified.
7825 bool isOffloadEntry =
7826 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
7827
7828 // For some private variables, the MapsForPrivatizedVariablesPass
7829 // creates MapInfoOp instances. Go through the private variables and
7830 // the mapped variables so that during code generation we are able
7831 // to quickly look up the corresponding map variable, if any for each
7832 // private variable.
7833 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7834 OperandRange privateVars = targetOp.getPrivateVars();
7835 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7836 std::optional<DenseI64ArrayAttr> privateMapIndices =
7837 targetOp.getPrivateMapsAttr();
7838
7839 for (auto [privVarIdx, privVarSymPair] :
7840 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7841 auto privVar = std::get<0>(privVarSymPair);
7842 auto privSym = std::get<1>(privVarSymPair);
7843
7844 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7845 omp::PrivateClauseOp privatizer =
7846 findPrivatizer(targetOp, privatizerName);
7847
7848 if (!privatizer.needsMap())
7849 continue;
7850
7851 mlir::Value mappedValue =
7852 targetOp.getMappedValueForPrivateVar(privVarIdx);
7853 assert(mappedValue && "Expected to find mapped value for a privatized "
7854 "variable that needs mapping");
7855
7856 // The MapInfoOp defining the map var isn't really needed later.
7857 // So, we don't store it in any datastructure. Instead, we just
7858 // do some sanity checks on it right now.
7859 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
7860 [[maybe_unused]] Type varType = mapInfoOp.getVarPtrType();
7861
7862 // Check #1: Check that the type of the private variable matches
7863 // the type of the variable being mapped.
7864 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7865 assert(
7866 varType == privVar.getType() &&
7867 "Type of private var doesn't match the type of the mapped value");
7868
7869 // Ok, only 1 sanity check for now.
7870 // Record the block argument corresponding to this mapvar.
7871 mappedPrivateVars.insert(
7872 {privVar,
7873 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7874 (*privateMapIndices)[privVarIdx])});
7875 }
7876 }
7877
7878 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 // Callback that generates the body of the outlined function: forwards
 // function attributes, maps block arguments, privatizes variables, converts
 // the target region and cleans up the private variables.
7879 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7880 ArrayRef<llvm::BasicBlock *> deallocBlocks)
7881 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7882 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7883 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7884 // Forward target-cpu and target-features function attributes from the
7885 // original function to the new outlined function.
7886 llvm::Function *llvmParentFn =
7887 moduleTranslation.lookupFunction(parentFn.getName());
7888 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7889 assert(llvmParentFn && llvmOutlinedFn &&
7890 "Both parent and outlined functions must exist at this point");
7891
7892 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7893 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
7894
7895 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
7896 attr.isStringAttribute())
7897 llvmOutlinedFn->addFnAttr(attr);
7898
7899 if (auto attr = llvmParentFn->getFnAttribute("target-features");
7900 attr.isStringAttribute())
7901 llvmOutlinedFn->addFnAttr(attr);
7902
 // Bind each map / has_device_addr block argument to the LLVM value of the
 // var_ptr operand of the MapInfoOp that defines the mapped value.
7903 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7904 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7905 llvm::Value *mapOpValue =
7906 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7907 moduleTranslation.mapValue(arg, mapOpValue);
7908 }
7909 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7910 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7911 llvm::Value *mapOpValue =
7912 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7913 moduleTranslation.mapValue(arg, mapOpValue);
7914 }
7915
7916 // Do privatization after moduleTranslation has already recorded
7917 // mapped values.
7918 PrivateVarsInfo privateVarsInfo(targetOp);
7919
7921 allocatePrivateVars(targetOp, builder, moduleTranslation,
7922 privateVarsInfo, allocaIP, &mappedPrivateVars);
7923
7924 if (failed(handleError(afterAllocas, *targetOp)))
7925 return llvm::make_error<PreviouslyReportedError>();
7926
7927 builder.restoreIP(codeGenIP);
7928 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
7929 &mappedPrivateVars),
7930 *targetOp)
7931 .failed())
7932 return llvm::make_error<PreviouslyReportedError>();
7933
7934 if (failed(copyFirstPrivateVars(
7935 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
7936 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
7937 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7938 return llvm::make_error<PreviouslyReportedError>();
7939
7941 moduleTranslation, allocaIP, deallocBlocks);
7943 targetRegion, "omp.target", builder, moduleTranslation);
7944
7945 if (failed(handleError(exitBlock, *targetOp)))
7946 return llvm::make_error<PreviouslyReportedError>();
7947
7948 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7949
7950 if (failed(cleanupPrivateVars(targetOp, builder, moduleTranslation,
7951 targetOp.getLoc(), privateVarsInfo)))
7952 return llvm::make_error<PreviouslyReportedError>();
7953
7954 return builder.saveIP();
7955 };
7956
7957 StringRef parentName = parentFn.getName();
7958
7959 llvm::TargetRegionEntryInfo entryInfo;
7960
7961 getTargetEntryUniqueInfo(entryInfo, targetOp,
7962 *moduleTranslation.getOpenMPBuilder(),
7963 moduleTranslation.getFileSystem(), parentName);
7964
7965 MapInfoData mapData;
7966 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
7967 builder, /*useDevPtrOperands=*/{},
7968 /*useDevAddrOperands=*/{}, hdaVars);
7969
7970 MapInfosTy combinedInfos;
 // Callback that fills `combinedInfos` from `mapData` and returns it by
 // reference; the same object is read later by customMapperCB.
7971 auto genMapInfoCB =
7972 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7973 builder.restoreIP(codeGenIP);
7974 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7975 targetDirective);
7976
7977 // Append a null entry for the implicit dyn_ptr argument so the argument
7978 // count sent to the runtime already includes it.
7979 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7980 combinedInfos.BasePointers.push_back(nullPtr);
7981 combinedInfos.Pointers.push_back(nullPtr);
7982 combinedInfos.DevicePointers.push_back(
7983 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7984 combinedInfos.Sizes.push_back(builder.getInt64(0));
7985 combinedInfos.Types.push_back(
7986 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7987 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7988 if (!combinedInfos.Names.empty())
7989 combinedInfos.Names.push_back(nullPtr);
7990 combinedInfos.Mappers.push_back(nullptr);
7991
7992 return combinedInfos;
7993 };
7994
7995 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7996 llvm::Value *&retVal, InsertPointTy allocaIP,
7997 InsertPointTy codeGenIP,
7999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
8000 llvm::IRBuilderBase::InsertPointGuard guard(builder);
8001 builder.SetCurrentDebugLocation(llvm::DebugLoc());
8002 // We just return the unaltered argument for the host function
8003 // for now, some alterations may be required in the future to
8004 // keep host fallback functions working identically to the device
8005 // version (e.g. pass ByCopy values should be treated as such on
8006 // host and device, currently not always the case)
8007 if (!isTargetDevice) {
8008 retVal = cast<llvm::Value>(&arg);
8009 return codeGenIP;
8010 }
8011
8012 return createDeviceArgumentAccessor(targetOp, mapData, arg, input, retVal,
8013 builder, *ompBuilder, moduleTranslation,
8014 allocaIP, codeGenIP, deallocIPs);
8015 };
8016
8017 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
8018 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
8019 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
8020 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
8021 isTargetDevice, isGPU);
8022
8023 // Collect host-evaluated values needed to properly launch the kernel from the
8024 // host.
8025 if (!isTargetDevice)
8026 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
8027 targetCapturedOp, runtimeAttrs);
8028
8029 // Pass host-evaluated values as parameters to the kernel / host fallback,
8030 // except if they are constants. In any case, map the MLIR block argument to
8031 // the corresponding LLVM values.
8033 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
8034 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
8035 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
8036 llvm::Value *value = moduleTranslation.lookupValue(var);
8037 moduleTranslation.mapValue(arg, value);
8038
8039 if (!llvm::isa<llvm::Constant>(value))
8040 kernelInput.push_back(value);
8041 }
8042
8043 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
8044 // 1) Declare target arguments are not passed to kernels as arguments.
8045 // 2) Attach maps are not passed in as arguments to kernels.
8046 // 3) Children of record objects are not passed in as arguments.
8047 // TODO: We currently do not handle cases where a member is explicitly
8048 // passed in as an argument, this will likely need to be handled in
8049 // the near future, rather than using IsAMember, it may be better to
8050 // test if the relevant BlockArg is used within the target region and
8051 // then use that as a basis for exclusion in the kernel inputs.
8052 bool isAttachMap = (mapData.Types[i] &
8053 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
8054 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
8055 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i] && !isAttachMap)
8056 kernelInput.push_back(mapData.OriginalValue[i]);
8057 }
8058
8060 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
8061 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
8062
8063 llvm::OpenMPIRBuilder::DependenciesInfo dds;
8064 if (failed(buildDependData(
8065 targetOp.getDependVars(), targetOp.getDependKinds(),
8066 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
8067 builder, moduleTranslation, dds)))
8068 return failure();
8069
8070 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8071
8072 llvm::OpenMPIRBuilder::TargetDataInfo info(
8073 /*RequiresDevicePointerInfo=*/false,
8074 /*SeparateBeginEndCalls=*/true);
8075
 // Callback returning the user-defined mapper function for map entry `i`, or
 // nullptr when that entry has no mapper. Sets info.HasMapper when one exists.
8076 auto customMapperCB =
8077 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
8078 if (!combinedInfos.Mappers[i])
8079 return nullptr;
8080 info.HasMapper = true;
8081 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
8082 moduleTranslation, targetDirective);
8083 };
8084
 // Translate the optional `if` clause condition; stays null when absent.
8085 llvm::Value *ifCond = nullptr;
8086 if (Value targetIfCond = targetOp.getIfExpr())
8087 ifCond = moduleTranslation.lookupValue(targetIfCond);
8088
 // Translate the optional dyn_groupprivate size and cast it to i32 using an
 // unsigned conversion; stays null when absent.
8089 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
8090 llvm::Value *dynSizeVal = nullptr;
8091 if (dynGroupPrivateSize) {
8092 dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
8093 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
8094 /*isSigned=*/false);
8095 }
8096
8097 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
8098 getDynGroupprivateFallbackType(targetOp.getDynGroupprivateFallbackAttr());
8099
8100 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8101 moduleTranslation.getOpenMPBuilder()->createTarget(
8102 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
8103 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
8104 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
8105 targetOp.getNowait(), dynSizeVal, fallbackType);
8106
8107 if (failed(handleError(afterIP, opInst)))
8108 return failure();
8109
8110 builder.restoreIP(*afterIP);
8111
 // Emit a free for the dependency array recorded in `dds`, if there is one.
8112 if (dds.DepArray)
8113 builder.CreateFree(dds.DepArray);
8114
8115 // Remap access operations to declare target reference pointers for the
8116 // device, essentially generating extra loadop's as necessary
8117 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
8118 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
8119 llvmOutlinedFn);
8120
8121 return success();
8122}
8123
8124static LogicalResult
8125convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
8126 llvm::OpenMPIRBuilder *ompBuilder,
8127 LLVM::ModuleTranslation &moduleTranslation) {
8128 // Amend omp.declare_target by deleting the IR of the outlined functions
8129 // created for target regions. They cannot be filtered out from MLIR earlier
8130 // because the omp.target operation inside must be translated to LLVM, but
8131 // the wrapper functions themselves must not remain at the end of the
8132 // process. We know that functions where omp.declare_target does not match
8133 // omp.is_target_device at this stage can only be wrapper functions because
8134 // those that aren't are removed earlier as an MLIR transformation pass.
8135 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
8136 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
8137 op->getParentOfType<ModuleOp>().getOperation())) {
8138 if (!offloadMod.getIsTargetDevice())
8139 return success();
8140
8141 omp::DeclareTargetDeviceType declareType =
8142 attribute.getDeviceType().getValue();
8143
8144 if (declareType == omp::DeclareTargetDeviceType::host) {
8145 llvm::Function *llvmFunc =
8146 moduleTranslation.lookupFunction(funcOp.getName());
8147 llvmFunc->dropAllReferences();
8148 llvmFunc->eraseFromParent();
8149
8150 // Invalidate the builder's current insertion point, as it now points to
8151 // a deleted block.
8152 ompBuilder->Builder.ClearInsertionPoint();
8153 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
8154 }
8155 }
8156 return success();
8157 }
8158
8159 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
8160 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8161 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
8162 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8163 bool isDeclaration = gOp.isDeclaration();
8164 bool isExternallyVisible =
8165 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
8166 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
8167 llvm::StringRef mangledName = gOp.getSymName();
8168 auto captureClause =
8169 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
8170 auto deviceClause =
8171 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
8172 // unused for MLIR at the moment, required in Clang for book
8173 // keeping
8174 std::vector<llvm::GlobalVariable *> generatedRefs;
8175
8176 std::vector<llvm::Triple> targetTriple;
8177 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
8178 op->getParentOfType<mlir::ModuleOp>()->getAttr(
8179 LLVM::LLVMDialect::getTargetTripleAttrName()));
8180 if (targetTripleAttr)
8181 targetTriple.emplace_back(targetTripleAttr.data());
8182
8183 auto fileInfoCallBack = [&loc]() {
8184 std::string filename = "";
8185 std::uint64_t lineNo = 0;
8186
8187 if (loc) {
8188 filename = loc.getFilename().str();
8189 lineNo = loc.getLine();
8190 }
8191
8192 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
8193 lineNo);
8194 };
8195
8196 llvm::vfs::FileSystem &vfs = moduleTranslation.getFileSystem();
8197
8198 ompBuilder->registerTargetGlobalVariable(
8199 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8200 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8201 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
8202 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
8203 gVal->getType(), gVal);
8204
8205 if (ompBuilder->Config.isTargetDevice() &&
8206 (attribute.getCaptureClause().getValue() !=
8207 mlir::omp::DeclareTargetCaptureClause::to ||
8208 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
8209 ompBuilder->getAddrOfDeclareTargetVar(
8210 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8211 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8212 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
8213 gVal->getType(), /*GlobalInitializer*/ nullptr,
8214 /*VariableLinkage*/ nullptr);
8215 }
8216 }
8217 }
8218
8219 return success();
8220}
8221
8222namespace {
8223
8224/// Implementation of the dialect interface that converts operations belonging
8225/// to the OpenMP dialect to LLVM IR.
8226class OpenMPDialectLLVMIRTranslationInterface
8227 : public LLVMTranslationDialectInterface {
8228public:
8229 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
8230
8231 /// Translates the given operation to LLVM IR using the provided IR builder
8232 /// and saving the state in `moduleTranslation`.
8233 LogicalResult
8234 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
8235 LLVM::ModuleTranslation &moduleTranslation) const final;
8236
8237 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
8238 /// runtime calls, or operation amendments
8239 LogicalResult
8240 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8241 NamedAttribute attribute,
8242 LLVM::ModuleTranslation &moduleTranslation) const final;
8243
8244 /// Records the LLVM alloc pointer produced for an OMP ALLOCATE variable so
8245 /// that the paired omp.allocate_free op can generate the matching
8246 /// __kmpc_free call.
8247 void registerAllocatedPtr(Value var, llvm::Value *ptr) const {
8248 ompAllocatedPtrs[var] = ptr;
8249 }
8250
8251 /// Returns the LLVM alloc pointer previously registered for var, or
8252 /// nullptr if no allocation was recorded.
8253 llvm::Value *lookupAllocatedPtr(Value var) const {
8254 auto it = ompAllocatedPtrs.find(var);
8255 return it != ompAllocatedPtrs.end() ? it->second : nullptr;
8256 }
8257
8258private:
8259 /// Maps each MLIR variable value that appeared in an omp.allocate_dir op to
8260 /// the LLVM pointer returned by the corresponding __kmpc_alloc call. The
8261 /// paired omp.allocate_free op looks up these pointers to emit __kmpc_free.
8262 mutable DenseMap<Value, llvm::Value *> ompAllocatedPtrs;
8263};
8264
8265} // namespace
8266
8267LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
8268 Operation *op, ArrayRef<llvm::Instruction *> instructions,
8269 NamedAttribute attribute,
8270 LLVM::ModuleTranslation &moduleTranslation) const {
8271 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
8272 attribute.getName())
8273 .Case("omp.is_target_device",
8274 [&](Attribute attr) {
8275 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
8276 llvm::OpenMPIRBuilderConfig &config =
8277 moduleTranslation.getOpenMPBuilder()->Config;
8278 config.setIsTargetDevice(deviceAttr.getValue());
8279 return success();
8280 }
8281 return failure();
8282 })
8283 .Case("omp.is_gpu",
8284 [&](Attribute attr) {
8285 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
8286 llvm::OpenMPIRBuilderConfig &config =
8287 moduleTranslation.getOpenMPBuilder()->Config;
8288 config.setIsGPU(gpuAttr.getValue());
8289 return success();
8290 }
8291 return failure();
8292 })
8293 .Case("omp.host_ir_filepath",
8294 [&](Attribute attr) {
8295 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
8296 llvm::OpenMPIRBuilder *ompBuilder =
8297 moduleTranslation.getOpenMPBuilder();
8298 ompBuilder->loadOffloadInfoMetadata(
8299 moduleTranslation.getFileSystem(), filepathAttr.getValue());
8300 return success();
8301 }
8302 return failure();
8303 })
8304 .Case("omp.flags",
8305 [&](Attribute attr) {
8306 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
8307 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
8308 return failure();
8309 })
8310 .Case("omp.version",
8311 [&](Attribute attr) {
8312 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
8313 llvm::OpenMPIRBuilder *ompBuilder =
8314 moduleTranslation.getOpenMPBuilder();
8315 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
8316 versionAttr.getVersion());
8317 return success();
8318 }
8319 return failure();
8320 })
8321 .Case("omp.declare_target",
8322 [&](Attribute attr) {
8323 if (auto declareTargetAttr =
8324 dyn_cast<omp::DeclareTargetAttr>(attr)) {
8325 llvm::OpenMPIRBuilder *ompBuilder =
8326 moduleTranslation.getOpenMPBuilder();
8327 return convertDeclareTargetAttr(op, declareTargetAttr,
8328 ompBuilder, moduleTranslation);
8329 }
8330 return failure();
8331 })
8332 .Case("omp.requires",
8333 [&](Attribute attr) {
8334 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
8335 using Requires = omp::ClauseRequires;
8336 Requires flags = requiresAttr.getValue();
8337 llvm::OpenMPIRBuilderConfig &config =
8338 moduleTranslation.getOpenMPBuilder()->Config;
8339 config.setHasRequiresReverseOffload(
8340 bitEnumContainsAll(flags, Requires::reverse_offload));
8341 config.setHasRequiresUnifiedAddress(
8342 bitEnumContainsAll(flags, Requires::unified_address));
8343 config.setHasRequiresUnifiedSharedMemory(
8344 bitEnumContainsAll(flags, Requires::unified_shared_memory));
8345 config.setHasRequiresDynamicAllocators(
8346 bitEnumContainsAll(flags, Requires::dynamic_allocators));
8347 return success();
8348 }
8349 return failure();
8350 })
8351 .Case("omp.target_triples",
8352 [&](Attribute attr) {
8353 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
8354 llvm::OpenMPIRBuilderConfig &config =
8355 moduleTranslation.getOpenMPBuilder()->Config;
8356 config.TargetTriples.clear();
8357 config.TargetTriples.reserve(triplesAttr.size());
8358 for (Attribute tripleAttr : triplesAttr) {
8359 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
8360 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
8361 else
8362 return failure();
8363 }
8364 return success();
8365 }
8366 return failure();
8367 })
8368 .Default([](Attribute) {
8369 // Fall through for omp attributes that do not require lowering.
8370 return success();
8371 })(attribute.getValue());
8372
8373 return failure();
8374}
8375
8376// Returns true if the operation is not inside a TargetOp, it is part of a
8377// function and that function is not declare target.
8378static bool isHostDeviceOp(Operation *op) {
8379 // Assumes no reverse offloading
8380 if (op->getParentOfType<omp::TargetOp>())
8381 return false;
8382
8383 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
8384 if (auto declareTargetIface =
8385 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
8386 parentFn.getOperation()))
8387 if (declareTargetIface.isDeclareTarget() &&
8388 declareTargetIface.getDeclareTargetDeviceType() !=
8389 mlir::omp::DeclareTargetDeviceType::host)
8390 return false;
8391
8392 return true;
8393 }
8394
8395 return false;
8396}
8397
8398static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
8399 llvm::Module *llvmModule) {
8400 llvm::Type *i64Ty = builder.getInt64Ty();
8401 llvm::Type *i32Ty = builder.getInt32Ty();
8402 llvm::Type *returnType = builder.getPtrTy(0);
8403 llvm::FunctionType *fnType =
8404 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
8405 llvm::Function *func = cast<llvm::Function>(
8406 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
8407 return func;
8408}
8409
8410template <typename T>
8411static llvm::Value *
8412getAllocationSize(llvm::IRBuilderBase &builder,
8413 LLVM::ModuleTranslation &moduleTranslation, T op) {
8414 llvm::DataLayout dataLayout =
8415 moduleTranslation.getLLVMModule()->getDataLayout();
8416 llvm::Type *llvmHeapTy =
8417 moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
8418
8419 auto alignment = op.getMemAlignment();
8420 llvm::TypeSize typeSize = llvm::alignTo(
8421 dataLayout.getTypeStoreSize(llvmHeapTy),
8422 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
8423
8424 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8425 return builder.CreateMul(
8426 allocSize,
8427 builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
8428 builder.getInt64Ty(),
8429 /*isSigned=*/false));
8430}
8431
8432template <>
8433llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
8434 LLVM::ModuleTranslation &moduleTranslation,
8435 omp::TargetAllocMemOp op) {
8436 llvm::DataLayout dataLayout =
8437 moduleTranslation.getLLVMModule()->getDataLayout();
8438 llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
8439 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
8440 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8441 for (auto typeParam : op.getTypeparams()) {
8442 allocSize = builder.CreateMul(
8443 allocSize,
8444 builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
8445 builder.getInt64Ty(),
8446 /*isSigned=*/false));
8447 }
8448 return allocSize;
8449}
8450
8451static LogicalResult
8452convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
8453 LLVM::ModuleTranslation &moduleTranslation) {
8454 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
8455 if (!allocMemOp)
8456 return failure();
8457
8458 // Get "omp_target_alloc" function
8459 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8460 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
8461 // Get the corresponding device value in llvm
8462 mlir::Value deviceNum = allocMemOp.getDevice();
8463 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
8464 // Get the allocation size.
8465 llvm::Value *allocSize =
8466 getAllocationSize(builder, moduleTranslation, allocMemOp);
8467 // Create call to "omp_target_alloc" with the args as translated llvm values.
8468 llvm::CallInst *call =
8469 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
8470 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
8471
8472 // Map the result
8473 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
8474 return success();
8475}
8476
8477static LogicalResult
8478convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
8479 llvm::IRBuilderBase &builder,
8480 LLVM::ModuleTranslation &moduleTranslation) {
8481 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8482 llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
8483 moduleTranslation.mapValue(allocMemOp.getResult(),
8484 ompBuilder->createOMPAllocShared(builder, size));
8485 return success();
8486}
8487
8488static LogicalResult
8489convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
8490 LLVM::ModuleTranslation &moduleTranslation,
8491 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8492 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8493 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8494
8495 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8496 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8497 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8498 SmallVector<Value> vars = allocateDirOp.getVarList();
8499 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
8500
8501 llvm::Value *allocator;
8502 if (auto allocatorVar = allocateDirOp.getAllocator()) {
8503 allocator = moduleTranslation.lookupValue(allocatorVar);
8504 if (allocator->getType()->isIntegerTy())
8505 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8506 else if (allocator->getType()->isPointerTy())
8507 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8508 allocator, builder.getPtrTy());
8509 } else {
8510 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8511 }
8512
8513 for (Value var : vars) {
8514 llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType());
8515
8516 // Opaque pointers lose element type. Trace to GlobalOp for type
8517 // Falls back to llvmVarTy when not from a global.
8518 llvm::Type *typeToInspect = llvmVarTy;
8519 if (llvmVarTy->isPointerTy()) {
8520 Value baseVar = getBaseValueForTypeLookup(var);
8521 if (Operation *globalOp = getGlobalOpFromValue(baseVar)) {
8522 if (auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8523 typeToInspect = moduleTranslation.convertType(gop.getGlobalType());
8524 }
8525 }
8526
8527 llvm::Value *size;
8528 if (auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8529 llvm::Value *elementCount = builder.getInt64(1);
8530 llvm::Type *currentType = arrTy;
8531 while (auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8532 elementCount = builder.CreateMul(
8533 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8534 currentType = nestedArrTy->getElementType();
8535 }
8536 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8537 size =
8538 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
8539 } else {
8540 size = builder.getInt64(
8541 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
8542 }
8543
8544 uint64_t alignValue =
8545 alignAttr ? alignAttr.value()
8546 : dataLayout.getABITypeAlign(typeToInspect).value();
8547 llvm::Value *alignConst = builder.getInt64(alignValue);
8548 // Align the size: ((size + align - 1) / align) * align
8549 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true);
8550 size = builder.CreateUDiv(size, alignConst);
8551 size = builder.CreateMul(size, alignConst, "", true);
8552
8553 std::string allocName =
8554 ompBuilder->createPlatformSpecificName({".void.addr"});
8555 llvm::CallInst *allocCall;
8556 if (alignAttr.has_value()) {
8557 allocCall = ompBuilder->createOMPAlignedAlloc(
8558 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8559 allocName);
8560 } else {
8561 allocCall =
8562 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
8563 }
8564 // Record the alloc pointer keyed by the MLIR variable value.
8565 ompIface.registerAllocatedPtr(var, allocCall);
8566 }
8567
8568 return success();
8569}
8570
8571static LogicalResult
8572convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder,
8573 LLVM::ModuleTranslation &moduleTranslation,
8574 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8575 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8576 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8577 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8578
8579 llvm::Value *allocator;
8580 if (auto allocatorVar = freeOp.getAllocator()) {
8581 allocator = moduleTranslation.lookupValue(allocatorVar);
8582 if (allocator->getType()->isIntegerTy())
8583 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8584 else if (allocator->getType()->isPointerTy())
8585 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8586 allocator, builder.getPtrTy());
8587 } else {
8588 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8589 }
8590
8591 // Emit __kmpc_free for each variable in reverse allocation order.
8592 SmallVector<Value> vars = freeOp.getVarList();
8593 for (Value var : llvm::reverse(vars)) {
8594 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
8595 if (!allocPtr)
8596 return opInst.emitError("omp.allocate_free: no allocation recorded");
8597 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator, "");
8598 }
8599
8600 return success();
8601}
8602
8603static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
8604 llvm::Module *llvmModule) {
8605 llvm::Type *ptrTy = builder.getPtrTy(0);
8606 llvm::Type *i32Ty = builder.getInt32Ty();
8607 llvm::Type *voidTy = builder.getVoidTy();
8608 llvm::FunctionType *fnType =
8609 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
8610 llvm::Function *func = dyn_cast<llvm::Function>(
8611 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
8612 return func;
8613}
8614
8615static LogicalResult
8616convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
8617 LLVM::ModuleTranslation &moduleTranslation) {
8618 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8619 if (!freeMemOp)
8620 return failure();
8621
8622 // Get "omp_target_free" function
8623 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8624 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
8625 // Get the corresponding device value in llvm
8626 mlir::Value deviceNum = freeMemOp.getDevice();
8627 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
8628 // Get the corresponding heapref value in llvm
8629 mlir::Value heapref = freeMemOp.getHeapref();
8630 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
8631 // Convert heapref int to ptr and call "omp_target_free"
8632 llvm::Value *intToPtr =
8633 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8634 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
8635 return success();
8636}
8637
8638static LogicalResult
8639convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
8640 llvm::IRBuilderBase &builder,
8641 LLVM::ModuleTranslation &moduleTranslation) {
8642 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8643 llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
8644 ompBuilder->createOMPFreeShared(
8645 builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
8646 return success();
8647}
8648
8649/// Converts an OpenMP groupprivate operation into LLVM IR.
///
/// The op's result is mapped to one of two pointers:
///  - When compiling for a target device and the op's device_type selects the
///    device, a new internal-linkage global of the referenced variable's type
///    is created in the target's shared/local address space (AMDGCN or NVPTX
///    only), with a poison initializer, and the result is mapped to it.
///  - Otherwise the result is mapped to the original LLVM global. On the host,
///    when allocation would have been requested, a warning is emitted first.
///
/// Fails if checkImplementationStatus() fails, if the symbol lookup yields no
/// global, or if the device target is neither AMDGCN nor NVPTX.
8650static LogicalResult
8651convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
8652 LLVM::ModuleTranslation &moduleTranslation) {
8653 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8654 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8655
8656 if (failed(checkImplementationStatus(opInst)))
8657 return failure();
8658
8659 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
8660
8661 // Determine whether group-private storage should be allocated based on
8662 // device_type. When not specified, default to 'any' (allocate on both).
8663 bool shouldAllocate = true;
8664 switch (groupprivateOp.getDeviceType().value_or(
8665 mlir::omp::DeclareTargetDeviceType::any)) {
8666 case mlir::omp::DeclareTargetDeviceType::host:
8667 shouldAllocate = !isTargetDevice;
8668 break;
8669 case mlir::omp::DeclareTargetDeviceType::nohost:
8670 shouldAllocate = isTargetDevice;
8671 break;
8672 case mlir::omp::DeclareTargetDeviceType::any:
8673 shouldAllocate = true;
8674 break;
8675 }
8676
8677 // Look up the global variable directly by symbol name.
// NOTE(review): source line 8678 is absent from this listing; it presumably
// declares `global` and opens the symbol lookup call whose arguments follow
// on the next line -- confirm against the upstream file.
8679 &opInst, groupprivateOp.getSymNameAttr());
8680 if (!global)
8681 return opInst.emitError()
8682 << "expected symbol '" << groupprivateOp.getSymName()
8683 << "' to reference an LLVM global variable";
8684
// Resolve the already-translated LLVM global and take its name and converted
// type; both are reused for the device-side shared variable below.
8685 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
8686 llvm::Type *varType = moduleTranslation.convertType(global.getType());
8687 std::string varName = globalValue->getName().str();
8688
8689 llvm::Value *resultPtr;
8690 if (shouldAllocate && isTargetDevice) {
// Device path: pick the address space from the module's target triple.
8691 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8692 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8693 unsigned sharedAddressSpace;
8694 if (targetTriple.isAMDGCN())
8695 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8696 else if (targetTriple.isNVPTX())
8697 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8698 else
8699 return opInst.emitError() << "groupprivate is not supported for target: "
8700 << targetTriple.str();
// The new variable is added to the module by this constructor (the module is
// passed as the first argument); it is requested under the original global's
// name.
8701 llvm::GlobalVariable *sharedVar = new llvm::GlobalVariable(
8702 *llvmModule, varType, /*isConstant=*/false,
8703 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8704 varName, /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
8705 sharedAddressSpace,
8706 /*isExternallyInitialized=*/false);
8707 resultPtr = sharedVar;
8708 } else {
// Host path, or device path where device_type excludes the device: reuse the
// original global. The warning fires only on the host.
8709 if (shouldAllocate && !isTargetDevice)
8710 opInst.emitWarning("groupprivate directive is currently ignored on the "
8711 "host, using original global");
8712 resultPtr = globalValue;
8713 }
8714
8715 moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
8716 return success();
8717}
8718
8719/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
8720/// OpenMP runtime calls).
///
/// Steps, in order:
///  1. When compiling for a target device, reject ops that isHostDeviceOp()
///     flags, except for the four op kinds exempted in the check.
///  2. Push an OpenMPLoopInfoStackFrame for the outermost loop wrapper (or for
///     a TaskloopContextOp, which stands in for its inner TaskloopWrapperOp).
///  3. Dispatch on the concrete op type through a TypeSwitch to the matching
///     convert*/apply* helper; ops without a case yield a "not yet
///     implemented" error.
///  4. Pop the stack frame if one was pushed, then return the dispatch result.
8721LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8722 Operation *op, llvm::IRBuilderBase &builder,
8723 LLVM::ModuleTranslation &moduleTranslation) const {
8724 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8725
8726 if (ompBuilder->Config.isTargetDevice() &&
8727 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8728 op) &&
8729 isHostDeviceOp(op))
8730 return op->emitOpError() << "unsupported host op found in device";
8731
8732 // For each loop, introduce one stack frame to hold loop information. Ensure
8733 // this is only done for the outermost loop wrapper to prevent introducing
8734 // multiple stack frames for a single loop. Initially set to null, the loop
8735 // information structure is initialized during translation of the nested
8736 // omp.loop_nest operation, making it available to translation of all loop
8737 // wrappers after their body has been successfully translated.
8738 bool isOutermostLoopWrapper =
8739 isa_and_present<omp::LoopWrapperInterface>(op) &&
8740 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
8741
8742 // The TASKLOOP construct is implemented with an outer taskloop.context
8743 // operation which is not a loop wrapper, containing an inner taskloop
8744 // operation which is a loop wrapper. The stack frame should be pushed when
8745 // translating the outer taskloop.context and popped when translating the
8746 // inner taskloop which is a loop wrapper. We need access to the loop
8747 // information in the outer taskloop context so we need to create it and pop
8748 // it around the taskloop context not the inner loop wrapper.
8749 if (isa<omp::TaskloopContextOp>(op))
8750 isOutermostLoopWrapper = true;
8751 else if (isa<omp::TaskloopWrapperOp>(op))
8752 isOutermostLoopWrapper = false;
8753
8754 if (isOutermostLoopWrapper)
8755 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
8756
// Dispatch on the concrete op type. Lambdas with an unnamed parameter use the
// enclosing `op` pointer and dereference it for helpers that take an
// Operation reference.
8757 auto result =
8758 llvm::TypeSwitch<Operation *, LogicalResult>(op)
8759 .Case([&](omp::BarrierOp op) -> LogicalResult {
// NOTE(review): source line 8760 is absent from this listing; the
// `return failure()` below is presumably guarded by a condition on
// that line -- confirm against the upstream file.
8761 return failure();
8762
8763 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8764 ompBuilder->createBarrier(builder.saveIP(),
8765 llvm::omp::OMPD_barrier);
8766 LogicalResult res = handleError(afterIP, *op);
8767 if (res.succeeded()) {
8768 // If the barrier generated a cancellation check, the insertion
8769 // point might now need to be changed to a new continuation block
8770 builder.restoreIP(*afterIP);
8771 }
8772 return res;
8773 })
8774 .Case([&](omp::TaskyieldOp op) {
// NOTE(review): source line 8775 is absent from this listing;
// presumably the guard for the `return failure()` below -- confirm.
8776 return failure();
8777
8778 ompBuilder->createTaskyield(builder.saveIP());
8779 return success();
8780 })
8781 .Case([&](omp::FlushOp op) {
// NOTE(review): source line 8782 is absent from this listing;
// presumably the guard for the `return failure()` below -- confirm.
8783 return failure();
8784
8785 // No support in OpenMP runtime function (__kmpc_flush) to accept
8786 // the argument list.
8787 // OpenMP standard states the following:
8788 // "An implementation may implement a flush with a list by ignoring
8789 // the list, and treating it the same as a flush without a list."
8790 //
8791 // The argument list is discarded so that, flush with a list is
8792 // treated same as a flush without a list.
8793 ompBuilder->createFlush(builder.saveIP());
8794 return success();
8795 })
8796 .Case([&](omp::ParallelOp op) {
8797 return convertOmpParallel(op, builder, moduleTranslation);
8798 })
8799 .Case([&](omp::MaskedOp) {
8800 return convertOmpMasked(*op, builder, moduleTranslation);
8801 })
8802 .Case([&](omp::MasterOp) {
8803 return convertOmpMaster(*op, builder, moduleTranslation);
8804 })
8805 .Case([&](omp::CriticalOp) {
8806 return convertOmpCritical(*op, builder, moduleTranslation);
8807 })
8808 .Case([&](omp::OrderedRegionOp) {
8809 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
8810 })
8811 .Case([&](omp::OrderedOp) {
8812 return convertOmpOrdered(*op, builder, moduleTranslation);
8813 })
8814 .Case([&](omp::WsloopOp) {
8815 return convertOmpWsloop(*op, builder, moduleTranslation);
8816 })
8817 .Case([&](omp::SimdOp) {
8818 return convertOmpSimd(*op, builder, moduleTranslation);
8819 })
8820 .Case([&](omp::AtomicReadOp) {
8821 return convertOmpAtomicRead(*op, builder, moduleTranslation);
8822 })
8823 .Case([&](omp::AtomicWriteOp) {
8824 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
8825 })
8826 .Case([&](omp::AtomicUpdateOp op) {
8827 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
8828 })
8829 .Case([&](omp::AtomicCaptureOp op) {
8830 return convertOmpAtomicCapture(op, builder, moduleTranslation);
8831 })
8832 .Case([&](omp::AtomicCompareOp op) {
8833 return convertOmpAtomicCompare(op, builder, moduleTranslation);
8834 })
8835 .Case([&](omp::CancelOp op) {
8836 return convertOmpCancel(op, builder, moduleTranslation);
8837 })
8838 .Case([&](omp::CancellationPointOp op) {
8839 return convertOmpCancellationPoint(op, builder, moduleTranslation);
8840 })
8841 .Case([&](omp::SectionsOp) {
8842 return convertOmpSections(*op, builder, moduleTranslation);
8843 })
8844 .Case([&](omp::ScopeOp op) {
8845 return convertOmpScope(op, builder, moduleTranslation);
8846 })
8847 .Case([&](omp::SingleOp op) {
8848 return convertOmpSingle(op, builder, moduleTranslation);
8849 })
8850 .Case([&](omp::TeamsOp op) {
8851 return convertOmpTeams(op, builder, moduleTranslation);
8852 })
8853 .Case([&](omp::TaskOp op) {
8854 return convertOmpTaskOp(op, builder, moduleTranslation);
8855 })
8856 .Case([&](omp::TaskloopWrapperOp op) {
8857 return convertOmpTaskloopWrapperOp(op, builder, moduleTranslation);
8858 })
8859 .Case([&](omp::TaskloopContextOp op) {
8860 return convertOmpTaskloopContextOp(op, builder, moduleTranslation);
8861 })
8862 .Case([&](omp::TaskgroupOp op) {
8863 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
8864 })
8865 .Case([&](omp::TaskwaitOp op) {
8866 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
8867 })
8868 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8869 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8870 omp::CriticalDeclareOp>([](auto op) {
8871 // `yield` and `terminator` can be just omitted. The block structure
8872 // was created in the region that handles their parent operation.
8873 // `declare_reduction` will be used by reductions and is not
8874 // converted directly, skip it.
8875 // `declare_mapper` and `declare_mapper.info` are handled whenever
8876 // they are referred to through a `map` clause.
8877 // `critical.declare` is only used to declare names of critical
8878 // sections which will be used by `critical` ops and hence can be
8879 // ignored for lowering. The OpenMP IRBuilder will create unique
8880 // name for critical section names.
8881 return success();
8882 })
8883 .Case([&](omp::ThreadprivateOp) {
8884 return convertOmpThreadprivate(*op, builder, moduleTranslation);
8885 })
8886 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8887 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
8888 return convertOmpTargetData(op, builder, moduleTranslation);
8889 })
8890 .Case([&](omp::TargetOp) {
8891 return convertOmpTarget(*op, builder, moduleTranslation);
8892 })
8893 .Case([&](omp::DistributeOp) {
8894 return convertOmpDistribute(*op, builder, moduleTranslation);
8895 })
8896 .Case([&](omp::LoopNestOp) {
8897 return convertOmpLoopNest(*op, builder, moduleTranslation);
8898 })
8899 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8900 omp::AffinityEntryOp, omp::IteratorOp>([&](auto op) {
8901 // No-op, should be handled by relevant owning operations e.g.
8902 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
8903 // etc. and then discarded
8904 return success();
8905 })
8906 .Case([&](omp::NewCliOp op) {
8907 // Meta-operation: Doesn't do anything by itself, but used to
8908 // identify a loop.
8909 return success();
8910 })
8911 .Case([&](omp::CanonicalLoopOp op) {
8912 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
8913 })
8914 .Case([&](omp::UnrollHeuristicOp op) {
8915 // FIXME: Handling omp.unroll_heuristic as an executable requires
8916 // that the generator (e.g. omp.canonical_loop) has been seen first.
8917 // For constructs that require all codegen to occur inside a callback
8918 // (e.g. OpenMPIRBuilder::createParallel), all codegen of that
8919 // contained region including their transformations must occur at
8920 // the omp.canonical_loop.
8921 return applyUnrollHeuristic(op, builder, moduleTranslation);
8922 })
8923 .Case([&](omp::TileOp op) {
8924 return applyTile(op, builder, moduleTranslation);
8925 })
8926 .Case([&](omp::FuseOp op) {
8927 return applyFuse(op, builder, moduleTranslation);
8928 })
8929 .Case([&](omp::TargetAllocMemOp) {
8930 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
8931 })
8932 .Case([&](omp::TargetFreeMemOp) {
8933 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
8934 })
8935 .Case([&](omp::AllocateDirOp) {
8936 return convertAllocateDirOp(*op, builder, moduleTranslation, *this);
8937 })
8938 .Case([&](omp::AllocateFreeOp) {
8939 return convertAllocateFreeOp(*op, builder, moduleTranslation,
8940 *this);
8941 })
8942 .Case([&](omp::AllocSharedMemOp op) {
8943 return convertAllocSharedMemOp(op, builder, moduleTranslation);
8944 })
8945 .Case([&](omp::FreeSharedMemOp op) {
8946 return convertFreeSharedMemOp(op, builder, moduleTranslation);
8947 })
8948 .Case([&](omp::GroupprivateOp) {
8949 return convertOmpGroupprivate(*op, builder, moduleTranslation);
8950 })
// Any op without a dedicated case above is reported as unimplemented.
8951 .Default([&](Operation *inst) {
8952 return inst->emitError()
8953 << "not yet implemented: " << inst->getName();
8954 });
8955
// Pop the frame pushed before dispatch; this runs whether or not the
// translation above succeeded.
8956 if (isOutermostLoopWrapper)
8957 moduleTranslation.stackPop();
8958
8959 return result;
8960}
8961
8963 registry.insert<omp::OpenMPDialect>();
8964 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
8965 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
8966 });
8967}
8968
8970 DialectRegistry registry;
8972 context.appendDialectRegistry(registry);
8973}
for(Operation *op :ops)
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static void mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag=llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam=true, int mapDataParentIdx=-1)
This function handles the insertion of a single item of map data from MapInfoData into the OMPIRBuild...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo, bool first=true)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.compare operation to LLVM IR.
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from a floating-point comparison predicate....
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from an integer comparison predicate. Returns std::nullopt f...
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at t...
Definition Block.h:222
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Value getOperand(unsigned idx)
Definition Operation.h:375
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
unsigned getNumOperands()
Definition Operation.h:371
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
result_range getResults()
Definition Operation.h:440
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:228
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
Definition Utils.cpp:66
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
Definition Utils.cpp:51
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.