MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
20#include "mlir/IR/Operation.h"
22#include "mlir/Support/LLVM.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Constants.h"
32#include "llvm/IR/DebugInfoMetadata.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/MDBuilder.h"
36#include "llvm/IR/ReplaceConstant.h"
37#include "llvm/Support/AMDGPUAddrSpace.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/Support/NVPTXAddrSpace.h"
40#include "llvm/Support/VirtualFileSystem.h"
41#include "llvm/TargetParser/Triple.h"
42#include "llvm/Transforms/Utils/ModuleUtils.h"
43
44#include <cstdint>
45#include <iterator>
46#include <numeric>
47#include <optional>
48#include <utility>
49
50using namespace mlir;
51
52namespace {
53static llvm::omp::ScheduleKind
54convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
55 if (!schedKind.has_value())
56 return llvm::omp::OMP_SCHEDULE_Default;
57 switch (schedKind.value()) {
58 case omp::ClauseScheduleKind::Static:
59 return llvm::omp::OMP_SCHEDULE_Static;
60 case omp::ClauseScheduleKind::Dynamic:
61 return llvm::omp::OMP_SCHEDULE_Dynamic;
62 case omp::ClauseScheduleKind::Guided:
63 return llvm::omp::OMP_SCHEDULE_Guided;
64 case omp::ClauseScheduleKind::Auto:
65 return llvm::omp::OMP_SCHEDULE_Auto;
66 case omp::ClauseScheduleKind::Runtime:
67 return llvm::omp::OMP_SCHEDULE_Runtime;
68 case omp::ClauseScheduleKind::Distribute:
69 return llvm::omp::OMP_SCHEDULE_Distribute;
70 }
71 llvm_unreachable("unhandled schedule clause argument");
72}
73
74/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
75/// insertion points for allocas.
76class OpenMPAllocStackFrame
77 : public StateStackFrameBase<OpenMPAllocStackFrame> {
78public:
80
81 explicit OpenMPAllocStackFrame(
82 llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
83 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks)
84 : allocInsertPoint(allocaIP), deallocBlocks(deallocBlocks) {}
85 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
86 llvm::SmallVector<llvm::BasicBlock *> deallocBlocks;
87};
88
89/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
90/// collapsed canonical loop information corresponding to an \c omp.loop_nest
91/// operation.
92class OpenMPLoopInfoStackFrame
93 : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
94public:
95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
96 llvm::CanonicalLoopInfo *loopInfo = nullptr;
97};
98
99/// Custom error class to signal translation errors that don't need reporting,
100/// since encountering them will have already triggered relevant error messages.
101///
102/// Its purpose is to serve as the glue between MLIR failures represented as
103/// \see LogicalResult instances and \see llvm::Error instances used to
104/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
105/// error of the first type is raised, a message is emitted directly (the \see
106/// LogicalResult itself does not hold any information). If we need to forward
107/// this error condition as an \see llvm::Error while avoiding triggering some
108/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
109/// class to just signal this situation has happened.
110///
111/// For example, this class should be used to trigger errors from within
112/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
113/// translation of their own regions. This unclutters the error log from
114/// redundant messages.
115class PreviouslyReportedError
116 : public llvm::ErrorInfo<PreviouslyReportedError> {
117public:
118 void log(raw_ostream &) const override {
119 // Do not log anything.
120 }
121
122 std::error_code convertToErrorCode() const override {
123 llvm_unreachable(
124 "PreviouslyReportedError doesn't support ECError conversion");
125 }
126
127 // Used by ErrorInfo::classID.
128 static char ID;
129};
130
131char PreviouslyReportedError::ID = 0;
132
133/*
134 * Custom class for processing linear clause for omp.wsloop
135 * and omp.simd. Linear clause translation requires setup,
136 * initialization, update, and finalization at varying
137 * basic blocks in the IR. This class helps maintain
138 * internal state to allow consistent translation in
139 * each of these stages.
140 */
141
142class LinearClauseProcessor {
143
144private:
145 SmallVector<llvm::Value *> linearPreconditionVars;
146 SmallVector<llvm::Value *> linearLoopBodyTemps;
147 SmallVector<llvm::Value *> linearOrigVal;
148 SmallVector<llvm::Value *> linearSteps;
149 SmallVector<llvm::Type *> linearVarTypes;
150 llvm::BasicBlock *linearFinalizationBB;
151 llvm::BasicBlock *linearExitBB;
152 llvm::BasicBlock *linearLastIterExitBB;
153
154public:
155 // Register type for the linear variables
156 void registerType(LLVM::ModuleTranslation &moduleTranslation,
157 mlir::Attribute &ty) {
158 linearVarTypes.push_back(moduleTranslation.convertType(
159 mlir::cast<mlir::TypeAttr>(ty).getValue()));
160 }
161
162 // Allocate space for linear variabes
163 void createLinearVar(llvm::IRBuilderBase &builder,
164 LLVM::ModuleTranslation &moduleTranslation,
165 llvm::Value *linearVar, int idx) {
166 linearPreconditionVars.push_back(
167 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
168 llvm::Value *linearLoopBodyTemp =
169 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
170 linearOrigVal.push_back(linearVar);
171 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
172 }
173
174 // Initialize linear step
175 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
176 mlir::Value &linearStep) {
177 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
178 }
179
180 // Emit IR for initialization of linear variables
181 void initLinearVar(llvm::IRBuilderBase &builder,
182 LLVM::ModuleTranslation &moduleTranslation,
183 llvm::BasicBlock *loopPreHeader) {
184 builder.SetInsertPoint(loopPreHeader->getTerminator());
185 for (size_t index = 0; index < linearOrigVal.size(); index++) {
186 llvm::LoadInst *linearVarLoad =
187 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
188 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
189 }
190 }
191
192 // Emit IR for updating Linear variables
193 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
194 llvm::Value *loopInductionVar) {
195 builder.SetInsertPoint(loopBody->getTerminator());
196 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
197 llvm::Type *linearVarType = linearVarTypes[index];
198 llvm::Value *iv = loopInductionVar;
199 llvm::Value *step = linearSteps[index];
200
201 if (!iv->getType()->isIntegerTy())
202 llvm_unreachable("OpenMP loop induction variable must be an integer "
203 "type");
204
205 if (linearVarType->isIntegerTy()) {
206 // Integer path: normalize all arithmetic to linearVarType
207 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
208 step = builder.CreateSExtOrTrunc(step, linearVarType);
209
210 llvm::LoadInst *linearVarStart =
211 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
214 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
215 } else if (linearVarType->isFloatingPointTy()) {
216 // Float path: perform multiply in integer, then convert to float
217 step = builder.CreateSExtOrTrunc(step, iv->getType());
218 llvm::Value *mulInst = builder.CreateMul(iv, step);
219
220 llvm::LoadInst *linearVarStart =
221 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
222 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
223 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
224 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
225 } else {
226 llvm_unreachable(
227 "Linear variable must be of integer or floating-point type");
228 }
229 }
230 }
231
232 // Linear variable finalization is conditional on the last logical iteration.
233 // Create BB splits to manage the same.
234 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
235 llvm::BasicBlock *loopExit) {
236 linearFinalizationBB = loopExit->splitBasicBlock(
237 loopExit->getTerminator(), "omp_loop.linear_finalization");
238 linearExitBB = linearFinalizationBB->splitBasicBlock(
239 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
240 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
241 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
242 }
243
244 // Finalize the linear vars
245 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
246 finalizeLinearVar(llvm::IRBuilderBase &builder,
247 LLVM::ModuleTranslation &moduleTranslation,
248 llvm::Value *lastIter) {
249 // Emit condition to check whether last logical iteration is being executed
250 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
251 llvm::Value *loopLastIterLoad = builder.CreateLoad(
252 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
253 llvm::Value *isLast =
254 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
255 llvm::ConstantInt::get(
256 llvm::Type::getInt32Ty(builder.getContext()), 0));
257 // Store the linear variable values to original variables.
258 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
259 for (size_t index = 0; index < linearOrigVal.size(); index++) {
260 llvm::LoadInst *linearVarTemp =
261 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
262 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
263 }
264
265 // Create conditional branch such that the linear variable
266 // values are stored to original variables only at the
267 // last logical iteration
268 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
269 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
270 linearFinalizationBB->getTerminator()->eraseFromParent();
271 // Emit barrier
272 builder.SetInsertPoint(linearExitBB->getTerminator());
273 return moduleTranslation.getOpenMPBuilder()->createBarrier(
274 builder.saveIP(), llvm::omp::OMPD_barrier);
275 }
276
277 // Emit stores for linear variables. Useful in case of SIMD
278 // construct.
279 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
280 for (size_t index = 0; index < linearOrigVal.size(); index++) {
281 llvm::LoadInst *linearVarTemp =
282 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
283 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
284 }
285 }
286
287 // Rewrite all uses of the original variable, in the basic blocks whose names
288 // start with `prefix`, with the linear variable in-place.
289 void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
290 llvm::BasicBlock *endBB, llvm::StringRef prefix,
291 size_t varIndex) {
292 llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
293 llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
294 llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;
295
296 assert(startBB && endBB && "Invalid startBB/endBB");
297
298 // Traverse basic blocks from startBB to endBB and save those
299 // whose names start with the specified prefix.
300 worklist.push_back(startBB);
301 visited.insert(startBB);
302
303 while (!worklist.empty()) {
304 llvm::BasicBlock *bb = worklist.pop_back_val();
305
306 if (bb->hasName() && bb->getName().starts_with(prefix))
307 matchingBBs.insert(bb);
308
309 if (bb == endBB)
310 continue;
311
312 for (llvm::BasicBlock *succ : llvm::successors(bb)) {
313 if (visited.insert(succ).second)
314 worklist.push_back(succ);
315 }
316 }
317
318 // Rewrite all uses in the matching BBs.
319 llvm::SmallVector<llvm::User *> users(linearOrigVal[varIndex]->users());
320 for (auto *user : users) {
321 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
322 if (matchingBBs.contains(userInst->getParent()))
323 user->replaceUsesOfWith(linearOrigVal[varIndex],
324 linearLoopBodyTemps[varIndex]);
325 }
326 }
327 }
328};
329
330} // namespace
331
/// Looks up, starting from the operation \p from, the omp::PrivateClauseOp
/// named \p symbolName and returns it.
///
/// The privatizer is expected to exist. This is only checked by an assertion,
/// so builds without assertions may return a null op.
static omp::PrivateClauseOp findPrivatizer(Operation *from,
                                           SymbolRefAttr symbolName) {
  // NOTE(review): the lookup call forming the initializer below (taking
  // `from` and `symbolName`) is missing from this listing; confirm against
  // the full file.
  omp::PrivateClauseOp privatizer =
      symbolName);
  assert(privatizer && "privatizer not found in the symbol table");
  return privatizer;
}
342
/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// Each `check*` lambda below inspects one clause. When the clause is used in
/// a way that cannot be translated yet, the lambda emits a "not yet
/// implemented" error on the operation and stores the resulting failure into
/// `result`. The lambdas never reset `result`, so a single failing check fails
/// the whole operation. `checkHint` is the exception: it only emits a warning.
///
/// \returns success if no unimplemented features are needed to translate the
/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits the common "not yet implemented" error for `clauseName` on `op`.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  // A `hint` clause is dropped with a warning; `result` is left untouched.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  };
  // Reduction clauses as such are only rejected on teams and taskloop-context
  // operations. For every operation this lambda is applied to, a reduction
  // modifier other than `defaultmod` is rejected.
  // NOTE(review): the TeamsOp case in the dispatch below does not call this
  // lambda, so its TeamsOp arm is currently unused.
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumTeamsMultiDim())
      result = todo("num_teams with multi-dimensional values");
  };
  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumThreadsMultiDim())
      result = todo("num_threads with multi-dimensional values");
  };

  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
    if (op.hasThreadLimitMultiDim())
      result = todo("thread_limit with multi-dimensional values");
  };
  // Map/motion clauses using the iterator modifier are not supported yet.
  auto checkMap = [&todo](auto op, LogicalResult &result) {
    if (!op.getMapIterated().empty())
      result = todo("map/motion clause with iterator modifier");
  };

  auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
    if (op.getDynGroupprivateSize())
      result = todo("dyn_groupprivate");
  };

  LogicalResult result = success();
  // NOTE(review): the expression the `.Case` chain below is applied to
  // (presumably a type switch over `op`) is missing from this listing;
  // confirm against the full file.
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ScopeOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkNumTeams(op, result);
        checkThreadLimit(op, result);
        checkDynGroupprivate(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopContextOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
        checkNumThreads(op, result);
      })
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case([&](omp::AtomicCompareOp op) {
        checkHint(op, result);
        // Struct-typed operands are rejected when their total size, computed
        // with the data layout of the enclosing module, exceeds 128 bits.
        Region &region = op.getRegion();
        if (region.empty())
          return;
        mlir::Type argType = region.front().getArgument(0).getType();
        auto structTy = dyn_cast<LLVM::LLVMStructType>(argType);
        if (!structTy)
          return;
        DataLayout dl = DataLayout(op->getParentOfType<ModuleOp>());
        unsigned totalBits = dl.getTypeSizeInBits(structTy);
        if (totalBits > 128)
          result = todo("compare for complex types wider than 128 bits");
      })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>([&](auto op) {
        checkDepend(op, result);
        checkMap(op, result);
      })
      .Case([&](omp::TargetUpdateOp op) {
        checkDepend(op, result);
        checkMap(op, result);
      })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkInReduction(op, result);
        checkMap(op, result);
        checkThreadLimit(op, result);
      })
      .Case([&](omp::TargetDataOp op) { checkMap(op, result); })
      .Case([&](omp::DeclareMapperInfoOp op) { checkMap(op, result); })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
519
520static LogicalResult handleError(llvm::Error error, Operation &op) {
521 LogicalResult result = success();
522 if (error) {
523 llvm::handleAllErrors(
524 std::move(error),
525 [&](const PreviouslyReportedError &) { result = failure(); },
526 [&](const llvm::ErrorInfoBase &err) {
527 result = op.emitError(err.message());
528 });
529 }
530 return result;
531}
532
533template <typename T>
534static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
535 if (!result)
536 return handleError(result.takeError(), op);
537
538 return success();
539}
540
541/// Find the insertion point for allocas given the current insertion point for
542/// normal operations in the builder.
543static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(
544 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
545 llvm::SmallVectorImpl<llvm::BasicBlock *> *deallocBlocks = nullptr) {
546 // If there is an allocation insertion point on stack, i.e. we are in a nested
547 // operation and a specific point was provided by some surrounding operation,
548 // use it.
549 llvm::OpenMPIRBuilder::InsertPointTy allocInsertPoint;
550 llvm::ArrayRef<llvm::BasicBlock *> deallocInsertPoints;
551 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocStackFrame>(
552 [&](OpenMPAllocStackFrame &frame) {
553 allocInsertPoint = frame.allocInsertPoint;
554 deallocInsertPoints = frame.deallocBlocks;
555 return WalkResult::interrupt();
556 });
557 // In cases with multiple levels of outlining, the tree walk might find an
558 // insertion point that is inside the original function while the builder
559 // insertion point is inside the outlined function. We need to make sure that
560 // we do not use it in those cases.
561 if (walkResult.wasInterrupted() &&
562 allocInsertPoint.getBlock()->getParent() ==
563 builder.GetInsertBlock()->getParent()) {
564 if (deallocBlocks)
565 deallocBlocks->insert(deallocBlocks->end(), deallocInsertPoints.begin(),
566 deallocInsertPoints.end());
567 return allocInsertPoint;
568 }
569
570 // Otherwise, insert to the entry block of the surrounding function.
571 // If the current IRBuilder InsertPoint is the function's entry, it cannot
572 // also be used for alloca insertion which would result in insertion order
573 // confusion. Create a new BasicBlock for the Builder and use the entry block
574 // for the allocs.
575 // TODO: Create a dedicated alloca BasicBlock at function creation such that
576 // we do not need to move the current InsertPoint here.
577 if (builder.GetInsertBlock() ==
578 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
579 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
580 "Assuming end of basic block");
581 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
582 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
583 builder.GetInsertBlock()->getNextNode());
584 builder.CreateBr(entryBB);
585 builder.SetInsertPoint(entryBB);
586 }
587
588 // Collect exit blocks, which is where explicit deallocations should happen in
589 // this case.
590 if (deallocBlocks) {
591 for (llvm::BasicBlock &block : *builder.GetInsertBlock()->getParent()) {
592 // TODO: This currently results in no blocks being added to the list when
593 // all exit blocks of the enclosing function have not been lowered before
594 // this is reached.
595 llvm::Instruction *terminator = block.getTerminatorOrNull();
596 if (isa_and_present<llvm::ReturnInst>(terminator))
597 deallocBlocks->emplace_back(&block);
598 }
599 }
600
601 llvm::BasicBlock &funcEntryBlock =
602 builder.GetInsertBlock()->getParent()->getEntryBlock();
603 return llvm::OpenMPIRBuilder::InsertPointTy(
604 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
605}
606
/// Find the loop information structure for the loop nest being translated. It
/// will return a `null` value unless called from the translation function for
/// a loop wrapper operation after successfully translating its body.
///
/// Only the first OpenMPLoopInfoStackFrame visited by the stack walk is
/// consulted; its `loopInfo` is returned as-is.
static llvm::CanonicalLoopInfo *
// NOTE(review): the line with this function's name and parameter list is
// missing from this listing; the body refers to a `moduleTranslation`
// parameter.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        loopInfo = frame.loopInfo;
        // Stop at the first matching frame.
        return WalkResult::interrupt();
      });
  return loopInfo;
}
620
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with a successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
///
/// `continuationBlock` is created by splitting the builder's current block at
/// its insertion point; the new block is named "omp.region.cont". It is
/// returned on success. If any block of the region fails to convert, a
/// PreviouslyReportedError is returned because the failure was already
/// diagnosed. Afterwards, the region's mappings are removed from
/// `moduleTranslation`, so the same region can be converted again.
// NOTE(review): the line naming this function and its return type is missing
// from this listing; callers use it as `convertOmpOpRegions(...)`.
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  // Loop wrapper regions are handled specially below: they have no
  // terminators.
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Create one LLVM block per region block up front and record the mapping.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // The first yield found defines the PHI types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Later yields must agree with it (checked in assertion builds only).
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    // `numYields` is used as the expected number of incoming values.
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  // NOTE(review): the declaration of `blocks` used by the loop below is
  // missing from this listing; confirm against the full file.
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      // Feed each forwarded value into the matching continuation-block PHI.
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
748
749/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
750static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
751 switch (kind) {
752 case omp::ClauseProcBindKind::Close:
753 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
754 case omp::ClauseProcBindKind::Master:
755 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
756 case omp::ClauseProcBindKind::Primary:
757 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
758 case omp::ClauseProcBindKind::Spread:
759 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
760 }
761 llvm_unreachable("Unknown ClauseProcBindKind kind");
762}
763
/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
///
/// The filter thread id is taken from the operation's filter operand when
/// present; otherwise it is the i32 constant 0. Returns failure if the
/// operation uses clauses that are not supported yet, or if
/// OpenMPIRBuilder::createMasked reports an error. On success, the builder is
/// left at the insertion point returned by createMasked.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto maskedOp = cast<omp::MaskedOp>(opInst);
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Translates the masked region at `codeGenIP`, forwarding any error.
  // NOTE(review): the remainder of this lambda's parameter list is missing
  // from this listing.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
    // MaskedOp has only one region associated with it.
    auto &region = maskedOp.getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.masked.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::Value *filterVal = nullptr;
  if (auto filterVar = maskedOp.getFilteredThreadId()) {
    filterVal = moduleTranslation.lookupValue(filterVar);
  } else {
    // No filter operand: default to thread 0.
    llvm::LLVMContext &llvmContext = builder.getContext();
    filterVal =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
  }
  assert(filterVal != nullptr);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
                                                         finiCB, filterVal);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
808
/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
///
/// Returns failure if the operation uses clauses that are not supported yet,
/// or if OpenMPIRBuilder::createMaster reports an error. On success, the
/// builder is left at the insertion point returned by createMaster.
static LogicalResult
convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto masterOp = cast<omp::MasterOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Translates the master region at `codeGenIP`, forwarding any error.
  // NOTE(review): the remainder of this lambda's parameter list is missing
  // from this listing.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
    // MasterOp has only one region associated with it.
    auto &region = masterOp.getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.master.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
                                                         finiCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
844
845/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
/// When the op has a name, the hint of the declaration that name refers to is
/// passed to createCritical as an i32 constant. Otherwise an empty name and a
/// null hint are passed.
846static LogicalResult
847convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
848 LLVM::ModuleTranslation &moduleTranslation) {
849 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
850 auto criticalOp = cast<omp::CriticalOp>(opInst);
851
852 if (failed(checkImplementationStatus(opInst)))
853 return failure();
854
855 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
857 // CriticalOp has only one region associated with it.
// (Re-casting opInst here yields the same op as `criticalOp`.)
858 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
859 builder.restoreIP(codeGenIP);
860 return convertOmpOpRegions(region, "omp.critical.region", builder,
861 moduleTranslation)
862 .takeError();
863 };
864
865 // TODO: Perform finalization actions for variables. This has to be
866 // called for variables which have destructors/finalizers.
867 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
868
869 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
870 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
// Stays null (no hint) for unnamed critical sections.
871 llvm::Constant *hint = nullptr;
872
873 // If it has a name, it probably has a hint too.
874 if (criticalOp.getNameAttr()) {
875 // The verifiers in the OpenMP dialect guarantee that all the pointers are
876 // non-null
877 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
878 auto criticalDeclareOp =
880 symbolRef);
881 hint =
882 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
883 static_cast<int>(criticalDeclareOp.getHint()));
884 }
885 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
886 moduleTranslation.getOpenMPBuilder()->createCritical(
887 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
888
889 if (failed(handleError(afterIP, opInst)))
890 return failure();
891
892 builder.restoreIP(*afterIP);
893 return success();
894}
895
896/// A util to collect info needed to convert delayed privatizers from MLIR to
897/// LLVM.
/// Construction records the op's private block arguments (`blockArgs`), its
/// private MLIR variables (`mlirVars`), and the privatizer declarations
/// referenced by its private symbols (`privatizers`). `llvmVars` is only
/// reserved here, not filled.
899 template <typename OP>
901 : blockArgs(
902 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
// Reserve one slot per private block argument for the MLIR and LLVM values.
903 mlirVars.reserve(blockArgs.size());
904 llvmVars.reserve(blockArgs.size());
905 collectPrivatizationDecls<OP>(op);
906
907 for (mlir::Value privateVar : op.getPrivateVars())
908 mlirVars.push_back(privateVar);
909 }
910
915
916private:
917 /// Populates `privatizers` with the privatizer declarations referenced by
918 /// the op's private symbols; does nothing if the op has none.
919 template <class OP>
920 void collectPrivatizationDecls(OP op) {
921 std::optional<ArrayAttr> attr = op.getPrivateSyms();
922 if (!attr)
923 return;
924
925 privatizers.reserve(privatizers.size() + attr->size());
926 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
927 privatizers.push_back(findPrivatizer(op, symbolRef));
928 }
929 }
930};
931
932/// Populates `reductions` with reduction declarations used in the given op.
/// One declaration is appended per symbol in the op's reduction symbol list,
/// in list order. Nothing is appended when the op has no reduction symbols.
933template <typename T>
934static void
937 std::optional<ArrayAttr> attr = op.getReductionSyms();
938 if (!attr)
939 return;
940
// NOTE(review): capacity is reserved from the number of reduction variables,
// which presumes one symbol per variable -- presumably enforced by the op
// verifier; confirm.
941 reductions.reserve(reductions.size() + op.getNumReductionVars());
942 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
943 reductions.push_back(
945 op, symbolRef));
946 }
947}
948
949/// Translates the blocks contained in the given region and appends them at
950/// the current insertion point of `builder`. The operations of the entry block
951/// are appended to the current insertion block. If set, `continuationBlockArgs`
952/// is populated with translated values that correspond to the values
953/// omp.yield'ed from the region.
/// An empty region is a no-op. A single-block region is emitted directly into
/// the current insert block, before any existing terminator. Otherwise the
/// region is converted via convertOmpOpRegions, and `builder` is left at the
/// first insertion point of the resulting continuation block.
954static LogicalResult inlineConvertOmpRegions(
955 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
956 LLVM::ModuleTranslation &moduleTranslation,
957 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
958 if (region.empty())
959 return success();
960
961 // Special case for single-block regions that don't create additional blocks:
962 // insert operations without creating additional blocks.
963 if (region.hasOneBlock()) {
964 llvm::Instruction *potentialTerminator =
965 builder.GetInsertBlock()->empty() ? nullptr
966 : &builder.GetInsertBlock()->back();
967
// Detach an existing terminator so the translated ops land before it; it is
// re-attached at the end of the insert block further down.
968 if (potentialTerminator && potentialTerminator->isTerminator())
969 potentialTerminator->removeFromParent();
970 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
971
// NOTE(review): on this failure path the terminator detached above is not
// re-inserted, leaving the insert block without it.
972 if (failed(moduleTranslation.convertBlock(
973 region.front(), /*ignoreArguments=*/true, builder)))
974 return failure();
975
976 // The continuation arguments are simply the translated terminator operands.
977 if (continuationBlockArgs)
978 llvm::append_range(
979 *continuationBlockArgs,
980 moduleTranslation.lookupValues(region.front().back().getOperands()));
981
982 // Drop the mapping that is no longer necessary so that the same region can
983 // be processed multiple times.
984 moduleTranslation.forgetMapping(region);
985
986 if (potentialTerminator && potentialTerminator->isTerminator()) {
987 llvm::BasicBlock *block = builder.GetInsertBlock();
988 if (block->empty()) {
989 // this can happen for really simple reduction init regions e.g.
990 // %0 = llvm.mlir.constant(0 : i32) : i32
991 // omp.yield(%0 : i32)
992 // because the llvm.mlir.constant (MLIR op) isn't converted into any
993 // llvm op
994 potentialTerminator->insertInto(block, block->begin());
995 } else {
996 potentialTerminator->insertAfter(&block->back());
997 }
998 }
999
1000 return success();
1001 }
1002
// General case: convertOmpOpRegions translates the region into new blocks and
// collects the yielded values in `phis`.
1004 llvm::Expected<llvm::BasicBlock *> continuationBlock =
1005 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
1006
1007 if (failed(handleError(continuationBlock, *region.getParentOp())))
1008 return failure();
1009
1010 if (continuationBlockArgs)
1011 llvm::append_range(*continuationBlockArgs, phis);
1012 builder.SetInsertPoint(*continuationBlock,
1013 (*continuationBlock)->getFirstInsertionPt());
1014 return success();
1015}
1016
1017namespace {
1018/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
1019/// store lambdas with capture.
1020using OwningReductionGen =
1021 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1022 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
1023 llvm::Value *&)>;
1024using OwningAtomicReductionGen =
1025 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1026 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
1027 llvm::Value *)>;
1028using OwningDataPtrPtrReductionGen =
1029 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
1030 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
1031} // namespace
1032
1033/// Create an OpenMPIRBuilder-compatible reduction generator for the given
1034/// reduction declaration. The generator uses `builder` but ignores its
1035/// insertion point.
/// When invoked, it maps the declaration's lhs/rhs arguments to the given
/// values, moves `builder` to the requested insertion point, inlines the
/// combiner region, and returns its single yielded value through `result`.
/// NOTE: `builder` and `moduleTranslation` are captured by reference, so the
/// returned generator must not outlive them.
1036static OwningReductionGen
1037makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1038 LLVM::ModuleTranslation &moduleTranslation) {
1039 // The lambda is mutable because we need access to non-const methods of decl
1040 // (which aren't actually mutating it), and we must capture decl by-value to
1041 // avoid the dangling reference after the parent function returns.
1042 OwningReductionGen gen =
1043 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1044 llvm::Value *lhs, llvm::Value *rhs,
1045 llvm::Value *&result) mutable
1046 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1047 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
1048 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1049 builder.restoreIP(insertPoint);
1051 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
1052 "omp.reduction.nonatomic.body", builder,
1053 moduleTranslation, &phis)))
1054 return llvm::createStringError(
1055 "failed to inline `combiner` region of `omp.declare_reduction`");
// The combiner region yields exactly one value: the combined result.
1056 result = llvm::getSingleElement(phis);
1057 return builder.saveIP();
1058 };
1059 return gen;
1060}
1061
1062/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
1063/// given reduction declaration. The generator uses `builder` but ignores its
1064/// insertion point. Returns null if there is no atomic region available in the
1065/// reduction declaration.
/// When invoked, it maps the atomic region's lhs/rhs arguments, inlines that
/// region at the requested insertion point, and expects it to yield nothing.
/// The element-type parameter is ignored. `builder` and `moduleTranslation`
/// are captured by reference, so the generator must not outlive them.
1066static OwningAtomicReductionGen
1067makeAtomicReductionGen(omp::DeclareReductionOp decl,
1068 llvm::IRBuilderBase &builder,
1069 LLVM::ModuleTranslation &moduleTranslation) {
1070 if (decl.getAtomicReductionRegion().empty())
1071 return OwningAtomicReductionGen();
1072
1073 // The lambda is mutable because we need access to non-const methods of decl
1074 // (which aren't actually mutating it), and we must capture decl by-value to
1075 // avoid the dangling reference after the parent function returns.
1076 OwningAtomicReductionGen atomicGen =
1077 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1078 llvm::Value *lhs, llvm::Value *rhs) mutable
1079 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1080 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1081 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1082 builder.restoreIP(insertPoint);
1084 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1085 "omp.reduction.atomic.body", builder,
1086 moduleTranslation, &phis)))
1087 return llvm::createStringError(
1088 "failed to inline `atomic` region of `omp.declare_reduction`");
// The atomic region updates lhs in place and yields no values.
1089 assert(phis.empty());
1090 return builder.saveIP();
1091 };
1092 return atomicGen;
1093}
1094
1095/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1096/// the given reduction declaration. The generator uses `builder` but ignores
1097/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1098/// available in the reduction declaration.
/// Also returns null when `isByRef` is false. When invoked, the generator maps
/// the region argument to `byRefVal`, inlines the region at the requested
/// insertion point, and returns its single yielded value through `result`.
1099static OwningDataPtrPtrReductionGen
1100makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1101 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1102 if (!isByRef || decl.getDataPtrPtrRegion().empty())
1103 return OwningDataPtrPtrReductionGen();
1104
// `decl` is captured by value (and the lambda made mutable) so it outlives this
// function; everything else is captured by reference.
1105 OwningDataPtrPtrReductionGen refDataPtrGen =
1106 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1107 llvm::Value *byRefVal, llvm::Value *&result) mutable
1108 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1109 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1110 builder.restoreIP(insertPoint);
1112 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1113 "omp.data_ptr_ptr.body", builder,
1114 moduleTranslation, &phis)))
1115 return llvm::createStringError(
1116 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1117 result = llvm::getSingleElement(phis);
1118 return builder.saveIP();
1119 };
1120
1121 return refDataPtrGen;
1122}
1123
1124/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1125static LogicalResult
1126convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1127 LLVM::ModuleTranslation &moduleTranslation) {
1128 auto orderedOp = cast<omp::OrderedOp>(opInst);
1129
1130 if (failed(checkImplementationStatus(opInst)))
1131 return failure();
1132
1133 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1134 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1135 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1136 SmallVector<llvm::Value *> vecValues =
1137 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1138
1139 size_t indexVecValues = 0;
1140 while (indexVecValues < vecValues.size()) {
1141 SmallVector<llvm::Value *> storeValues;
1142 storeValues.reserve(numLoops);
1143 for (unsigned i = 0; i < numLoops; i++) {
1144 storeValues.push_back(vecValues[indexVecValues]);
1145 indexVecValues++;
1146 }
1147 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1148 findAllocInsertPoints(builder, moduleTranslation);
1149 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1150 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1151 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1152 }
1153 return success();
1154}
1155
1156/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1157/// OpenMPIRBuilder.
1158static LogicalResult
1159convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1160 LLVM::ModuleTranslation &moduleTranslation) {
1161 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1162 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1163
1164 if (failed(checkImplementationStatus(opInst)))
1165 return failure();
1166
1167 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1168 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
1169 // OrderedOp has only one region associated with it.
1170 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1171 builder.restoreIP(codeGenIP);
1172 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1173 moduleTranslation)
1174 .takeError();
1175 };
1176
1177 // TODO: Perform finalization actions for variables. This has to be
1178 // called for variables which have destructors/finalizers.
1179 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1180
1181 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1182 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1183 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1184 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1185
1186 if (failed(handleError(afterIP, opInst)))
1187 return failure();
1188
1189 builder.restoreIP(*afterIP);
1190 return success();
1191}
1192
1193namespace {
1194/// Contains the arguments for an LLVM store operation
1195struct DeferredStore {
1196 DeferredStore(llvm::Value *value, llvm::Value *address)
1197 : value(value), address(address) {}
1198
1199 llvm::Value *value;
1200 llvm::Value *address;
1201};
1202} // namespace
1203
1204/// Allocate space for privatized reduction variables.
1205/// `deferredStores` contains information to create store operations which need
1206/// to be inserted after all allocas.
/// Allocations are emitted before the terminator of `allocaIP`'s block, or via
/// createOMPAllocShared when the op is in a shared device context.
/// - By-ref reductions with an `alloc` region inline that region, allocate a
///   slot holding the yielded pointer, and defer the store that fills it.
/// - By-ref reductions without one are skipped here.
/// - By-val reductions get a slot of the declared reduction type.
/// The builder's insertion point is restored on return (InsertPointGuard).
1207template <typename T>
1208static LogicalResult
1210 llvm::IRBuilderBase &builder,
1211 LLVM::ModuleTranslation &moduleTranslation,
1212 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1214 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1215 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1216 SmallVectorImpl<DeferredStore> &deferredStores,
1217 llvm::ArrayRef<bool> isByRefs) {
1218 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1219 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1220
1221 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1222 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1223
1224 // delay creating stores until after all allocas
1225 deferredStores.reserve(op.getNumReductionVars());
1226
1227 for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
1228 Region &allocRegion = reductionDecls[i].getAllocRegion();
1229 if (isByRefs[i]) {
// By-ref reductions without an alloc region are allocated later, in
// initReductionVars.
1230 if (allocRegion.empty())
1231 continue;
1232
1234 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1235 builder, moduleTranslation, &phis)))
1236 return op.emitError(
1237 "failed to inline `alloc` region of `omp.declare_reduction`");
1238
1239 assert(phis.size() == 1 && "expected one allocation to be yielded");
// Inlining may have moved the builder; return to the alloca block.
1240 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1241
1242 // Allocate reduction variable (which is a pointer to the real reduction
1243 // variable allocated in the inlined region)
1244 llvm::Type *ptrTy = builder.getPtrTy();
1245 llvm::Type *varTy =
1246 moduleTranslation.convertType(reductionDecls[i].getType());
1247 llvm::Value *var;
1248 if (useDeviceSharedMem) {
1249 var = ompBuilder->createOMPAllocShared(builder, varTy);
1250 } else {
1251 var = builder.CreateAlloca(varTy);
1252 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1253 }
1254
1255 llvm::Value *castPhi =
1256 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1257
1258 deferredStores.emplace_back(castPhi, var);
1259
1260 privateReductionVariables[i] = var;
1261 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1262 reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
1263 } else {
1264 assert(allocRegion.empty() &&
1265 "allocaction is implicit for by-val reduction");
1266
1267 llvm::Type *ptrTy = builder.getPtrTy();
1268 llvm::Type *varTy =
1269 moduleTranslation.convertType(reductionDecls[i].getType());
1270 llvm::Value *var;
1271 if (useDeviceSharedMem) {
1272 var = ompBuilder->createOMPAllocShared(builder, varTy);
1273 } else {
1274 var = builder.CreateAlloca(varTy);
1275 var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1276 }
1277
1278 moduleTranslation.mapValue(reductionArgs[i], var);
1279 privateReductionVariables[i] = var;
1280 reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
1281 }
1282 }
1283
1284 return success();
1285}
1286
1287/// Map input arguments to reduction initialization region
/// The initializer's mold argument is mapped to the translated i-th reduction
/// variable. If the region expects a non-pointer value but the variable is a
/// pointer, the value is loaded first as "omp_orig". When the initializer's
/// entry block has more than one argument, its alloc argument is mapped to the
/// variable's entry in `reductionVariableMap`.
1288template <typename T>
1289static void
1291 llvm::IRBuilderBase &builder,
1293 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1294 unsigned i) {
1295 // map input argument to the initialization region
1296 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1297 Region &initializerRegion = reduction.getInitializerRegion();
1298 Block &entry = initializerRegion.front();
1299
1300 mlir::Value mlirSource = loop.getReductionVars()[i];
1301 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1302 llvm::Value *origVal = llvmSource;
1303 // If a non-pointer value is expected, load the value from the source pointer.
1304 if (!isa<LLVM::LLVMPointerType>(
1305 reduction.getInitializerMoldArg().getType()) &&
1306 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1307 origVal =
1308 builder.CreateLoad(moduleTranslation.convertType(
1309 reduction.getInitializerMoldArg().getType()),
1310 llvmSource, "omp_orig");
1311 }
1312 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1313
1314 if (entry.getNumArguments() > 1) {
// NOTE(review): DenseMap::lookup yields nullptr if this variable was never
// recorded in reductionVariableMap; callers presumably guarantee it was.
1315 llvm::Value *allocation =
1316 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1317 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1318 }
1319}
1320
1321static void
1322setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1323 llvm::BasicBlock *block = nullptr) {
1324 if (block == nullptr)
1325 block = builder.GetInsertBlock();
1326
1327 if (!block->hasTerminator())
1328 builder.SetInsertPoint(block);
1329 else
1330 builder.SetInsertPoint(block->getTerminator());
1331}
1332
1333/// Inline reductions' `init` regions. This function assumes that the
1334/// `builder`'s insertion point is where the user wants the `init` regions to be
1335/// inlined; i.e. it does not try to find a proper insertion location for the
1336/// `init` regions. It also leaves the `builder`'s insertion point in a state
1337/// where the user can continue the code-gen directly afterwards.
/// Steps:
/// 1. Split off an "omp.reduction.init" block.
/// 2. Allocate slots in `latestAllocaBlock` for by-ref reductions that have no
///    `alloc` region.
/// 3. Emit the `deferredStores`, then inline each initializer region and store
///    its yielded value into the private variable.
1338template <typename OP>
1339static LogicalResult
1340initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1341 llvm::IRBuilderBase &builder,
1342 LLVM::ModuleTranslation &moduleTranslation,
1343 llvm::BasicBlock *latestAllocaBlock,
1345 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1346 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1347 llvm::ArrayRef<bool> isByRef,
1348 SmallVectorImpl<DeferredStore> &deferredStores) {
1349 if (op.getNumReductionVars() == 0)
1350 return success();
1351
1352 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1353 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1354
1355 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1356 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1357 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1358 builder.restoreIP(allocaIP);
1359 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1360
1361 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1362 if (isByRef[i]) {
1363 if (!reductionDecls[i].getAllocRegion().empty())
1364 continue;
1365
1366 // TODO: remove after all users of by-ref are updated to use the alloc
1367 // region: Allocate reduction variable (which is a pointer to the real
1368 // reduction variable allocated in the inlined region)
1369 llvm::Type *varTy =
1370 moduleTranslation.convertType(reductionDecls[i].getType());
1371 if (useDeviceSharedMem)
1372 byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
1373 else
1374 byRefVars[i] = builder.CreateAlloca(varTy);
1375 }
1376 }
1377
1378 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1379
1380 // store result of the alloc region to the allocated pointer to the real
1381 // reduction variable
1382 for (auto [data, addr] : deferredStores)
1383 builder.CreateStore(data, addr);
1384
1385 // Before the loop, store the initial values of reductions into reduction
1386 // variables. Although this could be done after allocas, we don't want to mess
1387 // up with the alloca insertion point.
1388 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1390
1391 // map block argument to initializer region
1392 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1393 reductionVariableMap, i);
1394
1395 // TODO In some cases (especially on the GPU), the init regions may
1396 // contain stack allocations. If the region is inlined in a loop, this is
1397 // problematic. Instead of just inlining the region, handle allocations by
1398 // hoisting fixed length allocations to the function entry and using
1399 // stacksave and restore for variable length ones.
1400 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1401 "omp.reduction.neutral", builder,
1402 moduleTranslation, &phis)))
1403 return failure();
1404
1405 assert(phis.size() == 1 && "expected one value to be yielded from the "
1406 "reduction neutral element declaration region");
1407
1409
1410 if (isByRef[i]) {
// NOTE(review): this `continue` also skips the forgetMapping call at the end of
// the loop body for by-ref reductions that have an alloc region; confirm the
// initializer mapping is not needed to be dropped in that case.
1411 if (!reductionDecls[i].getAllocRegion().empty())
1412 // done in allocReductionVars
1413 continue;
1414
1415 // TODO: this path can be removed once all users of by-ref are updated to
1416 // use an alloc region
1417
1418 // Store the result of the inlined region to the allocated reduction var
1419 // ptr
1420 builder.CreateStore(phis[0], byRefVars[i]);
1421
1422 privateReductionVariables[i] = byRefVars[i];
1423 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1424 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1425 } else {
1426 // by-val case: store the neutral value into the private variable
1427 builder.CreateStore(phis[0], privateReductionVariables[i]);
1428 // allocation and mapping were already done in allocReductionVars
1429 }
1430
1431 // forget the mapping for the initializer region because we might need a
1432 // different mapping if this reduction declaration is re-used for a
1433 // different variable
1434 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1435 }
1436
1437 return success();
1438}
1439
1440/// Collect reduction info
/// For each reduction of `loop`, create the owning non-atomic, atomic and
/// data_ptr_ptr generators, then build the matching ReductionInfo entry.
1441template <typename T>
1442static void collectReductionInfo(
1443 T loop, llvm::IRBuilderBase &builder,
1444 LLVM::ModuleTranslation &moduleTranslation,
1447 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1449 const ArrayRef<llvm::Value *> privateReductionVariables,
1451 ArrayRef<bool> isByRef) {
1452 unsigned numReductions = loop.getNumReductionVars();
1453
// All generators are created before any ReductionInfo refers to them, so
// later growth of the owning vectors cannot invalidate those references.
1454 for (unsigned i = 0; i < numReductions; ++i) {
1455 owningReductionGens.push_back(
1456 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1457 owningAtomicReductionGens.push_back(
1458 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
1460 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1461 }
1462
1463 // Collect the reduction information.
1464 reductionInfos.reserve(numReductions);
1465 for (unsigned i = 0; i < numReductions; ++i) {
// A null atomic generator (no atomic region) is passed on as nullptr.
1466 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1467 if (owningAtomicReductionGens[i])
1468 atomicGen = owningAtomicReductionGens[i];
1469 llvm::Value *variable =
1470 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
// Record the element type of an llvm.alloca in the `alloc` region, if any.
1471 mlir::Type allocatedType;
1472 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1473 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1474 allocatedType = alloca.getElemType();
1476 }
1477
1479 });
1480
1481 reductionInfos.push_back(
1482 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1483 privateReductionVariables[i],
1484 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1486 /*ReductionGenClang=*/nullptr, atomicGen,
1488 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1489 reductionDecls[i].getByrefElementType()
1490 ? moduleTranslation.convertType(
1491 *reductionDecls[i].getByrefElementType())
1492 : nullptr});
1493 }
1494}
1495
1496/// handling of DeclareReductionOp's cleanup region
/// Inlines every non-empty region in `cleanupRegions` at the builder's position,
/// placed before the insert block's terminator if it has one. The region's
/// first argument is mapped to the i-th entry of `privateVariables`. When
/// `shouldLoadCleanupRegionArg` is set, that entry is first loaded with the
/// argument's type. Assumes `privateVariables` has an entry for every
/// non-empty cleanup region. Returns failure if any region fails to inline.
1497static LogicalResult
1499 llvm::ArrayRef<llvm::Value *> privateVariables,
1500 LLVM::ModuleTranslation &moduleTranslation,
1501 llvm::IRBuilderBase &builder, StringRef regionName,
1502 bool shouldLoadCleanupRegionArg = true) {
1503 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1504 if (cleanupRegion->empty())
1505 continue;
1506
1507 // map the argument to the cleanup region
1508 Block &entry = cleanupRegion->front();
1509
1510 llvm::Instruction *potentialTerminator =
1511 builder.GetInsertBlock()->empty() ? nullptr
1512 : &builder.GetInsertBlock()->back();
1513 if (potentialTerminator && potentialTerminator->isTerminator())
1514 builder.SetInsertPoint(potentialTerminator);
1515 llvm::Value *privateVarValue =
1516 shouldLoadCleanupRegionArg
1517 ? builder.CreateLoad(
1518 moduleTranslation.convertType(entry.getArgument(0).getType()),
1519 privateVariables[i])
1520 : privateVariables[i];
1521
1522 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1523
1524 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1525 moduleTranslation)))
1526 return failure();
1527
1528 // clear block argument mapping in case it needs to be re-created with a
1529 // different source for another use of the same reduction decl
1530 moduleTranslation.forgetMapping(*cleanupRegion);
1531 }
1532 return success();
1533}
1534
1535// TODO: not used by ParallelOp
/// Emits the final combination of `op`'s reductions via
/// OpenMPIRBuilder::createReductions. It then adds a barrier (unless this is a
/// teams reduction), inlines each declaration's cleanup region, and frees
/// shared-memory private variables when the op is in a shared device context.
1536template <class OP>
1537static LogicalResult createReductionsAndCleanup(
1538 OP op, llvm::IRBuilderBase &builder,
1539 LLVM::ModuleTranslation &moduleTranslation,
1540 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1542 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1543 bool isNowait = false, bool isTeamsReduction = false) {
1544 // Process the reductions if required.
1545 if (op.getNumReductionVars() == 0)
1546 return success();
1547
1549 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1550 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1552
1553 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1554
1555 // Create the reduction generators. We need to own them here because
1556 // ReductionInfo only accepts references to the generators.
1557 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1558 owningReductionGens, owningAtomicReductionGens,
1559 owningReductionGenRefDataPtrGens,
1560 privateReductionVariables, reductionInfos, isByRef);
1561
1562 // The call to createReductions below expects the block to have a
1563 // terminator. Create an unreachable instruction to serve as terminator
1564 // and remove it later.
// NOTE(review): the early `return failure()` / emitOpError paths below leave
// this temporary terminator in the IR.
1565 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1566 builder.SetInsertPoint(tempTerminator);
1567 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1568 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1569 isByRef, isNowait, isTeamsReduction);
1570
1571 if (failed(handleError(contInsertPoint, *op)))
1572 return failure();
1573
1574 if (!contInsertPoint->getBlock())
1575 return op->emitOpError() << "failed to convert reductions";
1576
1577 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1578 if (!isTeamsReduction) {
1579 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1580 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1581
1582 if (failed(handleError(barrierIP, *op)))
1583 return failure();
1584 afterIP = *barrierIP;
1585 }
1586
1587 tempTerminator->eraseFromParent();
1588 builder.restoreIP(afterIP);
1589
1590 // after the construct, deallocate private reduction variables
1591 SmallVector<Region *> reductionRegions;
1592 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1593 [](omp::DeclareReductionOp reductionDecl) {
1594 return &reductionDecl.getCleanupRegion();
1595 });
1596 LogicalResult result = inlineOmpRegionCleanup(
1597 reductionRegions, privateReductionVariables, moduleTranslation, builder,
1598 "omp.reduction.cleanup");
1599
// Shared-memory frees are emitted even if a cleanup region failed to inline;
// the cleanup result is returned afterwards.
1600 bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1601 if (useDeviceSharedMem) {
1602 for (auto [var, reductionDecl] :
1603 llvm::zip_equal(privateReductionVariables, reductionDecls))
1604 ompBuilder->createOMPFreeShared(
1605 builder, var, moduleTranslation.convertType(reductionDecl.getType()));
1606 }
1607
1608 return result;
1609}
1610
1611static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1612 if (!attr)
1613 return {};
1614 return *attr;
1615}
1616
1617// TODO: not used by omp.parallel
/// Allocates `op`'s private reduction variables (allocReductionVars) and then
/// inlines their `init` regions (initReductionVars). The stores deferred by
/// the first step are handed to the second. Does nothing when the op has no
/// reduction variables.
1618template <typename OP>
1620 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1621 LLVM::ModuleTranslation &moduleTranslation,
1622 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1624 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1625 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1626 llvm::ArrayRef<bool> isByRef) {
1627 if (op.getNumReductionVars() == 0)
1628 return success();
1629
// Filled by allocReductionVars; emitted by initReductionVars after all allocas.
1630 SmallVector<DeferredStore> deferredStores;
1631
1632 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1633 allocaIP, reductionDecls,
1634 privateReductionVariables, reductionVariableMap,
1635 deferredStores, isByRef)))
1636 return failure();
1637
1638 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1639 allocaIP.getBlock(), reductionDecls,
1640 privateReductionVariables, reductionVariableMap,
1641 isByRef, deferredStores);
1642}
1643
1644/// Return the llvm::Value * corresponding to the `privateVar` that
1645/// is being privatized. It isn't always as simple as looking up
1646/// moduleTranslation with privateVar. For instance, in case of
1647/// an allocatable, the descriptor for the allocatable is privatized.
1648/// This descriptor is mapped using an MapInfoOp. So, this function
1649/// will return a pointer to the llvm::Value corresponding to the
1650/// block argument for the mapped descriptor.
1651static llvm::Value *
1652findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1653 LLVM::ModuleTranslation &moduleTranslation,
1654 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1655 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1656 return moduleTranslation.lookupValue(privateVar);
1657
1658 Value blockArg = (*mappedPrivateVars)[privateVar];
1659 Type privVarType = privateVar.getType();
1660 Type blockArgType = blockArg.getType();
1661 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1662 "A block argument corresponding to a mapped var should have "
1663 "!llvm.ptr type");
1664
1665 if (privVarType == blockArgType)
1666 return moduleTranslation.lookupValue(blockArg);
1667
1668 // This typically happens when the privatized type is lowered from
1669 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1670 // struct/pair is passed by value. But, mapped values are passed only as
1671 // pointers, so before we privatize, we must load the pointer.
1672 if (!isa<LLVM::LLVMPointerType>(privVarType))
1673 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1674 moduleTranslation.lookupValue(blockArg));
1675
1676 return moduleTranslation.lookupValue(privateVar);
1677}
1678
1679/// Initialize a single (first)private variable. You probably want to use
1680/// allocateAndInitPrivateVars instead of this.
1681/// This returns the private variable which has been initialized. This
1682/// variable should be mapped before constructing the body of the Op.
1684initPrivateVar(llvm::IRBuilderBase &builder,
1685 LLVM::ModuleTranslation &moduleTranslation,
1686 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1687 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1688 llvm::BasicBlock *privInitBlock,
1689 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1690 Region &initRegion = privDecl.getInitRegion();
1691 if (initRegion.empty())
1692 return llvmPrivateVar;
1693
1694 assert(nonPrivateVar);
1695 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1696 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1697
1698 // in-place convert the private initialization region
1700 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1701 moduleTranslation, &phis)))
1702 return llvm::createStringError(
1703 "failed to inline `init` region of `omp.private`");
1704
1705 assert(phis.size() == 1 && "expected one allocation to be yielded");
1706
1707 // clear init region block argument mapping in case it needs to be
1708 // re-created with a different source for another use of the same
1709 // `omp.private` decl
1710 moduleTranslation.forgetMapping(initRegion);
1711
1712 // Prefer the value yielded from the init region to the allocated private
1713 // variable in case the region is operating on arguments by-value (e.g.
1714 // Fortran character boxes).
1715 return phis[0];
1716}
1717
1718/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
1720 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1721 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1722 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1723 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1724 return initPrivateVar(
1725 builder, moduleTranslation, privDecl,
1726 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1727 mappedPrivateVars),
1728 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1729}
1730
1731static llvm::Error
1732initPrivateVars(llvm::IRBuilderBase &builder,
1733 LLVM::ModuleTranslation &moduleTranslation,
1734 PrivateVarsInfo &privateVarsInfo,
1735 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1736 if (privateVarsInfo.blockArgs.empty())
1737 return llvm::Error::success();
1738
1739 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1740 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1741
1742 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1743 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1744 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1745 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1747 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1748 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1749
1750 if (!privVarOrErr)
1751 return privVarOrErr.takeError();
1752
1753 llvmPrivateVar = privVarOrErr.get();
1754 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1755
1757 }
1758
1759 return llvm::Error::success();
1760}
1761
1762/// Allocate delayed private variables. Returns the basic block which comes
1763/// after all of these allocations. The llvm::Value * for each of these
1764/// private variables is appended to privateVarsInfo.llvmVars.
1765template <typename T>
1767allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
1768 LLVM::ModuleTranslation &moduleTranslation,
1769 PrivateVarsInfo &privateVarsInfo,
1770 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1771 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1772 // Allocate private vars
1773 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1774 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1775 allocaTerminator->getIterator()),
1776 true, allocaTerminator->getStableDebugLoc(),
1777 "omp.region.after_alloca");
1778
1779 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1780 // Update the allocaTerminator since the alloca block was split above.
1781 allocaTerminator = allocaIP.getBlock()->getTerminator();
1782 builder.SetInsertPoint(allocaTerminator);
1783 // The new terminator is an unconditional branch created by the splitBB above.
1784 assert(allocaTerminator->getNumSuccessors() == 1 &&
1785 "This is an unconditional branch created by splitBB");
1786
1787 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1788 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1789
1790 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1791 bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1792 unsigned int allocaAS =
1793 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1794 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1795 ->getDataLayout()
1796 .getProgramAddressSpace();
1797
1798 for (auto [privDecl, mlirPrivVar, blockArg] :
1799 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1800 privateVarsInfo.blockArgs)) {
1801 llvm::Type *llvmAllocType =
1802 moduleTranslation.convertType(privDecl.getType());
1803 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1804 llvm::Value *llvmPrivateVar = nullptr;
1805 if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
1806 llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
1807 } else {
1808 llvmPrivateVar = builder.CreateAlloca(
1809 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1810 if (allocaAS != defaultAS)
1811 llvmPrivateVar = builder.CreateAddrSpaceCast(
1812 llvmPrivateVar, builder.getPtrTy(defaultAS));
1813 }
1814
1815 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1816 }
1817
1818 return afterAllocas;
1819}
1820
1821/// This can't always be determined statically, but when we can, it is good to
1822/// avoid generating compiler-added barriers which will deadlock the program.
1824 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1825 parent = parent->getParentOp()) {
1826 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1827 return true;
1828
1829 // e.g.
1830 // omp.single {
1831 // omp.parallel {
1832 // op
1833 // }
1834 // }
1835 if (mlir::isa<omp::ParallelOp>(parent))
1836 return false;
1837 }
1838 return false;
1839}
1840
1841static LogicalResult copyFirstPrivateVars(
1842 mlir::Operation *op, llvm::IRBuilderBase &builder,
1843 LLVM::ModuleTranslation &moduleTranslation,
1845 ArrayRef<llvm::Value *> llvmPrivateVars,
1846 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1847 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1848 // Apply copy region for firstprivate.
1849 bool needsFirstprivate =
1850 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1851 return privOp.getDataSharingType() ==
1852 omp::DataSharingClauseType::FirstPrivate;
1853 });
1854
1855 if (!needsFirstprivate)
1856 return success();
1857
1858 llvm::BasicBlock *copyBlock =
1859 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1860 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1861
1862 for (auto [decl, moldVar, llvmVar] :
1863 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1864 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1865 continue;
1866
1867 // copyRegion implements `lhs = rhs`
1868 Region &copyRegion = decl.getCopyRegion();
1869
1870 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1871
1872 // map copyRegion lhs arg
1873 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1874
1875 // in-place convert copy region
1876 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1877 moduleTranslation)))
1878 return decl.emitError("failed to inline `copy` region of `omp.private`");
1879
1881
1882 // ignore unused value yielded from copy region
1883
1884 // clear copy region block argument mapping in case it needs to be
1885 // re-created with different sources for reuse of the same
1886 // `omp.private` decl
1887 moduleTranslation.forgetMapping(copyRegion);
1888 }
1889
1890 if (insertBarrier && !opIsInSingleThread(op)) {
1891 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1892 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1893 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1894 if (failed(handleError(res, *op)))
1895 return failure();
1896 }
1897
1898 return success();
1899}
1900
1901static LogicalResult copyFirstPrivateVars(
1902 mlir::Operation *op, llvm::IRBuilderBase &builder,
1903 LLVM::ModuleTranslation &moduleTranslation,
1904 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1905 ArrayRef<llvm::Value *> llvmPrivateVars,
1906 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1907 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1908 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1909 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1910 // map copyRegion rhs arg
1911 llvm::Value *moldVar = findAssociatedValue(
1912 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1913 assert(moldVar);
1914 return moldVar;
1915 });
1916 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1917 llvmPrivateVars, privateDecls, insertBarrier,
1918 mappedPrivateVars);
1919}
1920
1921template <typename T>
1922static LogicalResult
1923cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
1924 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1925 PrivateVarsInfo &privateVarsInfo) {
1926 // private variable deallocation
1927 SmallVector<Region *> privateCleanupRegions;
1928 llvm::transform(privateVarsInfo.privatizers,
1929 std::back_inserter(privateCleanupRegions),
1930 [](omp::PrivateClauseOp privatizer) {
1931 return &privatizer.getDeallocRegion();
1932 });
1933
1934 if (failed(inlineOmpRegionCleanup(privateCleanupRegions,
1935 privateVarsInfo.llvmVars, moduleTranslation,
1936 builder, "omp.private.dealloc",
1937 /*shouldLoadCleanupRegionArg=*/false)))
1938 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1939 "`omp.private` op in");
1940
1941 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1942 bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
1943 for (auto [privDecl, llvmPrivVar, blockArg] :
1944 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
1945 privateVarsInfo.blockArgs)) {
1946 if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
1947 ompBuilder->createOMPFreeShared(
1948 builder, llvmPrivVar,
1949 moduleTranslation.convertType(privDecl.getType()));
1950 }
1951 }
1952
1953 return success();
1954}
1955
1956/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1958 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1959 // be visible and not inside of function calls. This is enforced by the
1960 // verifier.
1961 return op
1962 ->walk([](Operation *child) {
1963 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1964 return WalkResult::interrupt();
1965 return WalkResult::advance();
1966 })
1967 .wasInterrupted();
1968}
1969
1970static LogicalResult
1971convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1972 LLVM::ModuleTranslation &moduleTranslation) {
1973 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1974 using StorableBodyGenCallbackTy =
1975 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1976
1977 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1978
1979 if (failed(checkImplementationStatus(opInst)))
1980 return failure();
1981
1982 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1983 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1984
1986 collectReductionDecls(sectionsOp, reductionDecls);
1987 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1988 findAllocInsertPoints(builder, moduleTranslation);
1989
1990 SmallVector<llvm::Value *> privateReductionVariables(
1991 sectionsOp.getNumReductionVars());
1992 DenseMap<Value, llvm::Value *> reductionVariableMap;
1993
1994 MutableArrayRef<BlockArgument> reductionArgs =
1995 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1996
1998 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1999 reductionDecls, privateReductionVariables, reductionVariableMap,
2000 isByRef)))
2001 return failure();
2002
2004
2005 for (Operation &op : *sectionsOp.getRegion().begin()) {
2006 auto sectionOp = dyn_cast<omp::SectionOp>(op);
2007 if (!sectionOp) // omp.terminator
2008 continue;
2009
2010 Region &region = sectionOp.getRegion();
2011 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
2012 InsertPointTy allocaIP, InsertPointTy codeGenIP,
2013 ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2014 builder.restoreIP(codeGenIP);
2015
2016 // map the omp.section reduction block argument to the omp.sections block
2017 // arguments
2018 // TODO: this assumes that the only block arguments are reduction
2019 // variables
2020 assert(region.getNumArguments() ==
2021 sectionsOp.getRegion().getNumArguments());
2022 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
2023 sectionsOp.getRegion().getArguments(), region.getArguments())) {
2024 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
2025 assert(llvmVal);
2026 moduleTranslation.mapValue(sectionArg, llvmVal);
2027 }
2028
2029 return convertOmpOpRegions(region, "omp.section.region", builder,
2030 moduleTranslation)
2031 .takeError();
2032 };
2033 sectionCBs.push_back(sectionCB);
2034 }
2035
2036 // No sections within omp.sections operation - skip generation. This situation
2037 // is only possible if there is only a terminator operation inside the
2038 // sections operation
2039 if (sectionCBs.empty())
2040 return success();
2041
2042 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
2043
2044 // TODO: Perform appropriate actions according to the data-sharing
2045 // attribute (shared, private, firstprivate, ...) of variables.
2046 // Currently defaults to shared.
2047 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
2048 llvm::Value &vPtr, llvm::Value *&replacementValue)
2049 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
2050 replacementValue = &vPtr;
2051 return codeGenIP;
2052 };
2053
2054 // TODO: Perform finalization actions for variables. This has to be
2055 // called for variables which have destructors/finalizers.
2056 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
2057
2058 allocaIP = findAllocInsertPoints(builder, moduleTranslation);
2059 bool isCancellable = constructIsCancellable(sectionsOp);
2060 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2061 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2062 moduleTranslation.getOpenMPBuilder()->createSections(
2063 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
2064 sectionsOp.getNowait());
2065
2066 if (failed(handleError(afterIP, opInst)))
2067 return failure();
2068
2069 builder.restoreIP(*afterIP);
2070
2071 // Process the reductions if required.
2073 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
2074 privateReductionVariables, isByRef, sectionsOp.getNowait());
2075}
2076
2077/// Converts an OpenMP scope construct into LLVM IR.
2078static LogicalResult
2079convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder,
2080 LLVM::ModuleTranslation &moduleTranslation) {
2081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2082 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2083
2084 if (failed(checkImplementationStatus(*scopeOp)))
2085 return failure();
2086
2087 llvm::ArrayRef<bool> isByRef = getIsByRef(scopeOp.getReductionByref());
2088 assert(isByRef.size() == scopeOp.getNumReductionVars());
2089
2090 PrivateVarsInfo privateVarsInfo(scopeOp);
2091
2093 collectReductionDecls(scopeOp, reductionDecls);
2094 InsertPointTy allocaIP = findAllocInsertPoints(builder, moduleTranslation);
2095
2096 SmallVector<llvm::Value *> privateReductionVariables(
2097 scopeOp.getNumReductionVars());
2098 DenseMap<Value, llvm::Value *> reductionVariableMap;
2099
2100 MutableArrayRef<BlockArgument> reductionArgs =
2101 cast<omp::BlockArgOpenMPOpInterface>(*scopeOp).getReductionBlockArgs();
2102
2103 // Allocate private vars before the scope body
2105 scopeOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
2106 if (failed(handleError(afterAllocas, *scopeOp)))
2107 return failure();
2108
2110 scopeOp, reductionArgs, builder, moduleTranslation, allocaIP,
2111 reductionDecls, privateReductionVariables, reductionVariableMap,
2112 isByRef)))
2113 return failure();
2114
2115 auto bodyCB =
2116 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
2117 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
2118 builder.restoreIP(codeGenIP);
2119
2120 if (handleError(
2121 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
2122 *scopeOp)
2123 .failed())
2124 return llvm::make_error<PreviouslyReportedError>();
2125
2126 if (failed(copyFirstPrivateVars(
2127 scopeOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2128 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
2129 scopeOp.getPrivateNeedsBarrier())))
2130 return llvm::make_error<PreviouslyReportedError>();
2131
2132 return convertOmpOpRegions(scopeOp.getRegion(), "omp.scope.region", builder,
2133 moduleTranslation)
2134 .takeError();
2135 };
2136
2137 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
2138 InsertPointTy oldIP = builder.saveIP();
2139 builder.restoreIP(codeGenIP);
2140 if (failed(cleanupPrivateVars(scopeOp, builder, moduleTranslation,
2141 scopeOp.getLoc(), privateVarsInfo)))
2142 return llvm::make_error<PreviouslyReportedError>();
2143 builder.restoreIP(oldIP);
2144 return llvm::Error::success();
2145 };
2146
2147 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2148 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2149 ompBuilder->createScope(ompLoc, bodyCB, finiCB, scopeOp.getNowait());
2150
2151 if (failed(handleError(afterIP, *scopeOp)))
2152 return failure();
2153
2154 builder.restoreIP(*afterIP);
2155
2156 // Process the reductions if required.
2158 scopeOp, builder, moduleTranslation, allocaIP, reductionDecls,
2159 privateReductionVariables, isByRef, scopeOp.getNowait(),
2160 /*isTeamsReduction=*/false);
2161}
2162
2163/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
2164static LogicalResult
2165convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
2166 LLVM::ModuleTranslation &moduleTranslation) {
2167 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2168 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2169
2170 if (failed(checkImplementationStatus(*singleOp)))
2171 return failure();
2172
2173 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2174 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2175 builder.restoreIP(codegenIP);
2176 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
2177 builder, moduleTranslation)
2178 .takeError();
2179 };
2180 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
2181
2182 // Handle copyprivate
2183 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
2184 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
2187 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
2188 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
2190 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
2191 llvmCPFuncs.push_back(
2192 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
2193 }
2194
2195 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2196 moduleTranslation.getOpenMPBuilder()->createSingle(
2197 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
2198 llvmCPFuncs);
2199
2200 if (failed(handleError(afterIP, *singleOp)))
2201 return failure();
2202
2203 builder.restoreIP(*afterIP);
2204 return success();
2205}
2206
2207static omp::DistributeOp
2209 // Early return if we found more than one distribute op or if we can't find
2210 // any distribute op in the teams region.
2211 omp::DistributeOp distOp;
2212 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
2213 if (distOp)
2214 return WalkResult::interrupt();
2215 distOp = op;
2216 return WalkResult::skip();
2217 });
2218 if (walk.wasInterrupted() || !distOp)
2219 return {};
2220
2221 auto iface =
2222 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2223 // Check that all uses of the reduction block arg has the same distribute op
2224 // parent.
2226 for (auto ra : iface.getReductionBlockArgs())
2227 for (auto &use : ra.getUses()) {
2228 auto *useOp = use.getOwner();
2229 // Ignore debug uses.
2230 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2231 debugUses.push_back(useOp);
2232 continue;
2233 }
2234 if (!distOp->isProperAncestor(useOp))
2235 return {};
2236 }
2237
2238 // If we are going to use distribute reduction then remove any debug uses of
2239 // the reduction parameters in teamsOp. Otherwise they will be left without
2240 // any mapped value in moduleTranslation and will eventually error out.
2241 for (auto *use : debugUses)
2242 use->erase();
2243 return distOp;
2244}
2245
2246// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
2247static LogicalResult
2248convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2249 LLVM::ModuleTranslation &moduleTranslation) {
2250 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2251 if (failed(checkImplementationStatus(*op)))
2252 return failure();
2253
2254 DenseMap<Value, llvm::Value *> reductionVariableMap;
2255 unsigned numReductionVars = op.getNumReductionVars();
2257 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
2258 llvm::ArrayRef<bool> isByRef;
2259 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2260 findAllocInsertPoints(builder, moduleTranslation);
2261
2262 // Only do teams reduction if there is no distribute op that captures the
2263 // reduction instead.
2264 bool doTeamsReduction = !getDistributeCapturingTeamsReduction(op);
2265 if (doTeamsReduction) {
2266 isByRef = getIsByRef(op.getReductionByref());
2267
2268 assert(isByRef.size() == op.getNumReductionVars());
2269
2270 MutableArrayRef<BlockArgument> reductionArgs =
2271 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2272
2273 collectReductionDecls(op, reductionDecls);
2274
2276 op, reductionArgs, builder, moduleTranslation, allocaIP,
2277 reductionDecls, privateReductionVariables, reductionVariableMap,
2278 isByRef)))
2279 return failure();
2280 }
2281
2282 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2283 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
2285 moduleTranslation, allocaIP, deallocBlocks);
2286 builder.restoreIP(codegenIP);
2287 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2288 moduleTranslation)
2289 .takeError();
2290 };
2291
2292 llvm::Value *numTeamsLower = nullptr;
2293 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2294 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2295
2296 llvm::Value *numTeamsUpper = nullptr;
2297 if (!op.getNumTeamsUpperVars().empty())
2298 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2299
2300 llvm::Value *threadLimit = nullptr;
2301 if (!op.getThreadLimitVars().empty())
2302 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2303
2304 llvm::Value *ifExpr = nullptr;
2305 if (Value ifVar = op.getIfExpr())
2306 ifExpr = moduleTranslation.lookupValue(ifVar);
2307
2308 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2309 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2310 moduleTranslation.getOpenMPBuilder()->createTeams(
2311 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2312
2313 if (failed(handleError(afterIP, *op)))
2314 return failure();
2315
2316 builder.restoreIP(*afterIP);
2317 if (doTeamsReduction) {
2318 // Process the reductions if required.
2320 op, builder, moduleTranslation, allocaIP, reductionDecls,
2321 privateReductionVariables, isByRef,
2322 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2323 }
2324 return success();
2325}
2326
2327static llvm::omp::RTLDependenceKindTy
2328convertDependKind(mlir::omp::ClauseTaskDepend kind) {
2329 switch (kind) {
2330 case mlir::omp::ClauseTaskDepend::taskdependin:
2331 return llvm::omp::RTLDependenceKindTy::DepIn;
2332 // The OpenMP runtime requires that the codegen for 'depend' clause for
2333 // 'out' dependency kind must be the same as codegen for 'depend' clause
2334 // with 'inout' dependency.
2335 case mlir::omp::ClauseTaskDepend::taskdependout:
2336 case mlir::omp::ClauseTaskDepend::taskdependinout:
2337 return llvm::omp::RTLDependenceKindTy::DepInOut;
2338 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2339 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2340 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2341 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2342 }
2343 llvm_unreachable("unhandled depend kind");
2344}
2345
2347 std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2348 LLVM::ModuleTranslation &moduleTranslation,
2350 if (dependVars.empty())
2351 return;
2352 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2353 auto kind =
2354 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2355 llvm::omp::RTLDependenceKindTy type = convertDependKind(kind);
2356 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2357 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2358 dds.emplace_back(dd);
2359 }
2360}
2361
2362/// Shared implementation of a callback which adds a terminator for the new block
2363/// created for the branch taken when an openmp construct is cancelled. The
2364/// terminator is saved in \p cancelTerminators. This callback is invoked only
2365/// if there is cancellation inside of the taskgroup body.
2366/// The terminator will need to be fixed to branch to the correct block to
2367/// cleanup the construct.
2369 SmallVectorImpl<llvm::UncondBrInst *> &cancelTerminators,
2370 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2371 mlir::Operation *op, llvm::omp::Directive cancelDirective) {
2372 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2373 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2374
2375 // ip is currently in the block branched to if cancellation occurred.
2376 // We need to create a branch to terminate that block.
2377 llvmBuilder.restoreIP(ip);
2378
2379 // We must still clean up the construct after cancelling it, so we need to
2380 // branch to the block that finalizes the taskgroup.
2381 // That block has not been created yet so use this block as a dummy for now
2382 // and fix this after creating the operation.
2383 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2384 return llvm::Error::success();
2385 };
2386 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2387 // created in case the body contains omp.cancel (which will then expect to be
2388 // able to find this cleanup callback).
2389 ompBuilder.pushFinalizationCB(
2390 {finiCB, cancelDirective, constructIsCancellable(op)});
2391}
2392
2393/// If we cancelled the construct, we should branch to the finalization block of
2394/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2395/// is immediately before the continuation block. Now this finalization has
2396/// been created we can fix the branch.
2397static void
2399 llvm::OpenMPIRBuilder &ompBuilder,
2400 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2401 ompBuilder.popFinalizationCB();
2402 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2403 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2404 cancelBranch->setSuccessor(constructFini);
2405}
2406
2407namespace {
2408/// TaskContextStructManager takes care of creating and freeing a structure
2409/// containing information needed by the task body to execute.
2410class TaskContextStructManager {
2411public:
2412 TaskContextStructManager(llvm::IRBuilderBase &builder,
2413 LLVM::ModuleTranslation &moduleTranslation,
2414 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2415 : builder{builder}, moduleTranslation{moduleTranslation},
2416 privateDecls{privateDecls} {}
2417
2418 /// Creates a heap allocated struct containing space for each private
2419 /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
2420 /// the structure should all have the same order (although privateDecls which
2421 /// do not read from the mold argument are skipped).
2422 void generateTaskContextStruct();
2423
2424 /// Create GEPs to access each member of the structure representing a private
2425 /// variable, adding them to llvmPrivateVars. Null values are added where
2426 /// private decls were skipped so that the ordering continues to match the
2427 /// private decls.
2428 void createGEPsToPrivateVars();
2429
2430 /// Given the address of the structure, return a GEP for each private variable
2431 /// in the structure. Null values are added where private decls were skipped
2432 /// so that the ordering continues to match the private decls.
2433 /// Must be called after generateTaskContextStruct().
2434 SmallVector<llvm::Value *>
2435 createGEPsToPrivateVars(llvm::Value *altStructPtr) const;
2436
2437 /// De-allocate the task context structure.
2438 void freeStructPtr();
2439
2440 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2441 return llvmPrivateVarGEPs;
2442 }
2443
2444 llvm::Value *getStructPtr() { return structPtr; }
2445
2446private:
2447 llvm::IRBuilderBase &builder;
2448 LLVM::ModuleTranslation &moduleTranslation;
2449 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2450
2451 /// The type of each member of the structure, in order.
2452 SmallVector<llvm::Type *> privateVarTypes;
2453
2454 /// LLVM values for each private variable, or null if that private variable is
2455 /// not included in the task context structure
2456 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2457
2458 /// A pointer to the structure containing context for this task.
2459 llvm::Value *structPtr = nullptr;
2460 /// The type of the structure
2461 llvm::Type *structTy = nullptr;
2462};
2463
2464/// IteratorInfo extracts and prepares loop bounds information from an
2465/// mlir::omp::IteratorOp for lowering to LLVM IR.
2466///
2467/// It computes the per-dimension trip counts and the total linearized trip
2468/// count, casted to i64. These are used to build a canonical loop and to
2469/// reconstruct the physical induction variables inside the loop body.
2470class IteratorInfo {
2471private:
2472 llvm::SmallVector<llvm::Value *> lowerBounds;
2473 llvm::SmallVector<llvm::Value *> upperBounds;
2474 llvm::SmallVector<llvm::Value *> steps;
2475 llvm::SmallVector<llvm::Value *> trips;
2476 unsigned dims;
2477 llvm::Value *totalTrips;
2478
2479 llvm::Value *lookUpAsI64(mlir::Value val, const LLVM::ModuleTranslation &mt,
2480 llvm::IRBuilderBase &builder) {
2481 llvm::Value *v = mt.lookupValue(val);
2482 if (!v)
2483 return nullptr;
2484 if (v->getType()->isIntegerTy(64))
2485 return v;
2486 if (v->getType()->isIntegerTy())
2487 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2488 return nullptr;
2489 }
2490
2491public:
2492 IteratorInfo(mlir::omp::IteratorOp itersOp,
2493 mlir::LLVM::ModuleTranslation &moduleTranslation,
2494 llvm::IRBuilderBase &builder) {
2495 dims = itersOp.getLoopLowerBounds().size();
2496 lowerBounds.resize(dims);
2497 upperBounds.resize(dims);
2498 steps.resize(dims);
2499 trips.resize(dims);
2500
2501 for (unsigned d = 0; d < dims; ++d) {
2502 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2503 moduleTranslation, builder);
2504 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2505 moduleTranslation, builder);
2506 llvm::Value *st =
2507 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2508 assert(lb && ub && st &&
2509 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2510 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2511 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2512 "Expect non-zero step in IteratorOp");
2513
2514 lowerBounds[d] = lb;
2515 upperBounds[d] = ub;
2516 steps[d] = st;
2517
2518 // trips = ((ub - lb) / step) + 1 (inclusive ub, assume positive step)
2519 llvm::Value *diff = builder.CreateSub(ub, lb);
2520 llvm::Value *div = builder.CreateSDiv(diff, st);
2521 trips[d] = builder.CreateAdd(
2522 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2523 }
2524
2525 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2526 for (unsigned d = 0; d < dims; ++d)
2527 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2528 }
2529
2530 unsigned getDims() const { return dims; }
2531 llvm::ArrayRef<llvm::Value *> getLowerBounds() const { return lowerBounds; }
2532 llvm::ArrayRef<llvm::Value *> getUpperBounds() const { return upperBounds; }
2533 llvm::ArrayRef<llvm::Value *> getSteps() const { return steps; }
2534 llvm::ArrayRef<llvm::Value *> getTrips() const { return trips; }
2535 llvm::Value *getTotalTrips() const { return totalTrips; }
2536};
2537
2538} // namespace
2539
2540void TaskContextStructManager::generateTaskContextStruct() {
2541 if (privateDecls.empty())
2542 return;
2543 privateVarTypes.reserve(privateDecls.size());
2544
2545 for (omp::PrivateClauseOp &privOp : privateDecls) {
2546 // Skip private variables which can safely be allocated and initialised
2547 // inside of the task
2548 if (!privOp.readsFromMold())
2549 continue;
2550 Type mlirType = privOp.getType();
2551 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2552 }
2553
2554 if (privateVarTypes.empty())
2555 return;
2556
2557 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2558 privateVarTypes);
2559
2560 llvm::DataLayout dataLayout =
2561 builder.GetInsertBlock()->getModule()->getDataLayout();
2562 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2563 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2564
2565 // Heap allocate the structure
2566 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2567 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2568 "omp.task.context_ptr");
2569}
2570
2571SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2572 llvm::Value *altStructPtr) const {
2573 SmallVector<llvm::Value *> ret;
2574
2575 // Create GEPs for each struct member
2576 ret.reserve(privateDecls.size());
2577 llvm::Value *zero = builder.getInt32(0);
2578 unsigned i = 0;
2579 for (auto privDecl : privateDecls) {
2580 if (!privDecl.readsFromMold()) {
2581 // Handle this inside of the task so we don't pass unnessecary vars in
2582 ret.push_back(nullptr);
2583 continue;
2584 }
2585 llvm::Value *iVal = builder.getInt32(i);
2586 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2587 ret.push_back(gep);
2588 i += 1;
2589 }
2590 return ret;
2591}
2592
2593void TaskContextStructManager::createGEPsToPrivateVars() {
2594 if (!structPtr)
2595 assert(privateVarTypes.empty());
2596 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2597 // with null values for skipped private decls
2598
2599 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2600}
2601
2602void TaskContextStructManager::freeStructPtr() {
2603 if (!structPtr)
2604 return;
2605
2606 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2607 // Ensure we don't put the call to free() after the terminator
2608 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2609 builder.CreateFree(structPtr);
2610}
2611
2612static void storeAffinityEntry(llvm::IRBuilderBase &builder,
2613 llvm::OpenMPIRBuilder &ompBuilder,
2614 llvm::Value *affinityList, llvm::Value *index,
2615 llvm::Value *addr, llvm::Value *len) {
2616 llvm::StructType *kmpTaskAffinityInfoTy =
2617 ompBuilder.getKmpTaskAffinityInfoTy();
2618 llvm::Value *entry = builder.CreateInBoundsGEP(
2619 kmpTaskAffinityInfoTy, affinityList, index, "omp.affinity.entry");
2620
2621 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2622 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2623 /*isSigned=*/false);
2624 llvm::Value *flags = builder.getInt32(0);
2625
2626 builder.CreateStore(addr,
2627 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2628 builder.CreateStore(len,
2629 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2630 builder.CreateStore(flags,
2631 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2632}
2633
2635 llvm::IRBuilderBase &builder,
2636 LLVM::ModuleTranslation &moduleTranslation,
2637 llvm::Value *affinityList) {
 // Write one affinity record per operand of `affinityVars` into
 // `affinityList`, in operand order. Record i is built from the address and
 // length carried by the omp.affinity_entry operation defining operand i.
2638 for (auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2639 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2640 assert(entryOp && "affinity item must be omp.affinity_entry");
2641
2642 llvm::Value *addr = moduleTranslation.lookupValue(entryOp.getAddr());
2643 llvm::Value *len = moduleTranslation.lookupValue(entryOp.getLen());
2644 assert(addr && len && "expect affinity addr and len to be non-null");
 // Records are addressed with a constant i64 index into the list.
2645 storeAffinityEntry(builder, *moduleTranslation.getOpenMPBuilder(),
2646 affinityList, builder.getInt64(i), addr, len);
2647 }
2648}
2649
2650static mlir::LogicalResult
2651convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo,
2652 mlir::Block &iteratorRegionBlock,
2653 llvm::IRBuilderBase &builder,
2654 LLVM::ModuleTranslation &moduleTranslation) {
2655 llvm::Value *tmp = linearIV;
2656 for (int d = (int)iterInfo.getDims() - 1; d >= 0; --d) {
2657 llvm::Value *trip = iterInfo.getTrips()[d];
2658 // idx_d = tmp % trip_d
2659 llvm::Value *idx = builder.CreateURem(tmp, trip);
2660 // tmp = tmp / trip_d
2661 tmp = builder.CreateUDiv(tmp, trip);
2662
2663 // physIV_d = lb_d + idx_d * step_d
2664 llvm::Value *physIV = builder.CreateAdd(
2665 iterInfo.getLowerBounds()[d],
2666 builder.CreateMul(idx, iterInfo.getSteps()[d]), "omp.it.phys_iv");
2667
2668 moduleTranslation.mapValue(iteratorRegionBlock.getArgument(d), physIV);
2669 }
2670
2671 // Translate the iterator region into the loop body.
2672 moduleTranslation.mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2673 if (mlir::failed(moduleTranslation.convertBlock(iteratorRegionBlock,
2674 /*ignoreArguments=*/true,
2675 builder))) {
2676 return mlir::failure();
2677 }
2678 return mlir::success();
2679}
2680
2682 llvm::function_ref<void(llvm::Value *linearIV, mlir::omp::YieldOp yield)>;
2683
2684static mlir::LogicalResult
2685fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder,
2686 mlir::LLVM::ModuleTranslation &moduleTranslation,
2687 IteratorInfo &iterInfo, llvm::StringRef loopName,
2688 IteratorStoreEntryTy genStoreEntry) {
2689 mlir::Region &itersRegion = itersOp.getRegion();
2690 mlir::Block &iteratorRegionBlock = itersRegion.front();
2691
2692 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2693
2694 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2695 llvm::Value *linearIV) -> llvm::Error {
2696 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2697 builder.restoreIP(bodyIP);
2698
2699 if (failed(convertIteratorRegion(linearIV, iterInfo, iteratorRegionBlock,
2700 builder, moduleTranslation))) {
2701 return llvm::make_error<llvm::StringError>(
2702 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2703 }
2704
2705 auto yield =
2706 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.getTerminator());
2707 assert(yield && yield.getResults().size() == 1 &&
2708 "expect omp.yield in iterator region to have one result");
2709
2710 genStoreEntry(linearIV, yield);
2711
2712 // Iterator-region block/value mappings are temporary for this conversion,
2713 // clear them to avoid stale entries in ModuleTranslation.
2714 moduleTranslation.forgetMapping(itersRegion);
2715
2716 return llvm::Error::success();
2717 };
2718
2719 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2720 moduleTranslation.getOpenMPBuilder()->createIteratorLoop(
2721 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2722 if (failed(handleError(afterIP, *itersOp)))
2723 return failure();
2724
2725 builder.restoreIP(*afterIP);
2726
2727 return mlir::success();
2728}
2729
/// Build the affinity information passed to createTask() for `taskOp` into
/// `ad`.
///
/// - With neither affinity vars nor iterated affinity operands, ad.Count and
///   ad.Info are set to nullptr.
/// - Non-iterated omp.affinity_entry operands are written into one list whose
///   length is the number of affinity vars.
/// - Each omp.iterator in getIterated() gets its own list, sized by its total
///   trip count and filled by an iterator loop.
/// - When more than one list was produced, all of them are memcpy'd into a
///   single packed list.
///
/// On return, ad.Count is the total i32 entry count and ad.Info points to the
/// (possibly packed) list.
2730static mlir::LogicalResult
2731buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder,
2732 mlir::LLVM::ModuleTranslation &moduleTranslation,
2733 llvm::OpenMPIRBuilder::AffinityData &ad) {
2734
2735 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2736 ad.Count = nullptr;
2737 ad.Info = nullptr;
2738 return mlir::success();
2739 }
2740
2742 llvm::StructType *kmpTaskAffinityInfoTy =
2743 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2744
 // Allocas with a Constant/Argument count are placed at the function's alloca
 // insertion point. Any other count is allocated at the current insertion
 // point.
 // NOTE(review): that yields a dynamic alloca; if this code ever runs inside
 // a loop the stack may grow per iteration -- confirm this is acceptable.
2745 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2746 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2747 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2748 builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
2749 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2750 "omp.affinity_list");
2751 };
2752
 // Wrap a (count, list) pair as AffinityData: count truncated to i32, list
 // cast to an address-space-0 pointer. (The local `ad` shadows the
 // out-parameter of the same name.)
2753 auto createAffinity =
2754 [&](llvm::Value *count,
2755 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2756 llvm::OpenMPIRBuilder::AffinityData ad{};
2757 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2758 ad.Info =
2759 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2760 return ad;
2761 };
2762
 // Non-iterated affinity entries: one list of known, constant length.
2763 if (!taskOp.getAffinityVars().empty()) {
2764 llvm::Value *count = llvm::ConstantInt::get(
2765 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2766 llvm::Value *list = allocateAffinityList(count);
2767 fillAffinityLocators(taskOp.getAffinityVars(), builder, moduleTranslation,
2768 list);
2769 ads.emplace_back(createAffinity(count, list));
2770 }
2771
 // Iterated affinity entries: one list per omp.iterator, with entry
 // `linearIV` written on each iteration of that iterator's loop.
2772 if (!taskOp.getIterated().empty()) {
2773 for (auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2774 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2775 assert(itersOp && "iterated value must be defined by omp.iterator");
2776 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2777 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2778 if (failed(fillIteratorLoop(
2779 itersOp, builder, moduleTranslation, iterInfo, "iterator",
2780 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2781 auto entryOp = yield.getResults()[0]
2782 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2783 assert(entryOp && "expect yield produce an affinity entry");
2784 llvm::Value *addr =
2785 moduleTranslation.lookupValue(entryOp.getAddr());
2786 llvm::Value *len =
2787 moduleTranslation.lookupValue(entryOp.getLen());
2788 storeAffinityEntry(builder,
2789 *moduleTranslation.getOpenMPBuilder(),
2790 affList, linearIV, addr, len);
2791 })))
2792 return llvm::failure();
2793 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2794 }
2795 }
2796
 // Total number of entries across all lists, as i32.
2797 llvm::Value *totalAffinityCount = builder.getInt32(0);
2798 for (const auto &affinity : ads)
2799 totalAffinityCount = builder.CreateAdd(
2800 totalAffinityCount,
2801 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2802 /*isSigned=*/false));
2803
 // A single list is passed through as-is; several lists are concatenated
 // into one packed list, since AffinityData carries a single pointer.
2804 llvm::Value *affinityInfo = ads.front().Info;
2805 if (ads.size() > 1) {
 // (Shadows the identical kmpTaskAffinityInfoTy declared above.)
2806 llvm::StructType *kmpTaskAffinityInfoTy =
2807 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2808 llvm::Value *affinityInfoElemSize = builder.getInt64(
2809 moduleTranslation.getLLVMModule()->getDataLayout().getTypeAllocSize(
2810 kmpTaskAffinityInfoTy));
2811
2812 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2813 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2814 for (const auto &affinity : ads) {
2815 llvm::Value *affinityCount = builder.CreateIntCast(
2816 affinity.Count, builder.getInt32Ty(), /*isSigned=*/false);
2817 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2818 affinityCount, builder.getInt64Ty(), /*isSigned=*/false);
 // Byte size of this list = entry count * allocation size of one record.
2819 llvm::Value *affinityInfoSize =
2820 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2821
 // The running i32 offset is widened to the record's first field type only
 // to serve as the GEP index of the destination slot.
2822 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2823 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2824 /*isSigned=*/false);
2825 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2826 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2827
2828 builder.CreateMemCpy(
2829 packedAffinityInfoIndex, llvm::Align(1),
2830 builder.CreatePointerBitCastOrAddrSpaceCast(
2831 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2832 ->getPointerAddressSpace())),
2833 llvm::Align(1), affinityInfoSize);
2834
2835 packedAffinityInfoOffset =
2836 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2837 }
2838
2839 affinityInfo = packedAffinityInfo;
2840 }
2841
2842 ad.Count = totalAffinityCount;
2843 ad.Info = affinityInfo;
2844
2845 return mlir::success();
2846}
2847
2848// Allocates a single kmp_dep_info array sized to hold both locator
2849// (non-iterated) and iterated entries, fills the locator entries first, then
2850// runs an iterator loop for each iterator modifier object.
//
// Without iterated dependences, only taskDeps.Deps is filled (from
// buildDependDataLocator) and nothing is allocated. Otherwise the array is
// malloc'd here, stored in taskDeps.DepArray with its i32 length in
// taskDeps.NumDeps, and must be freed by the caller.
2851static mlir::LogicalResult
2852buildDependData(OperandRange dependVars, std::optional<ArrayAttr> dependKinds,
2853 OperandRange dependIterated,
2854 std::optional<ArrayAttr> dependIteratedKinds,
2855 llvm::IRBuilderBase &builder,
2856 mlir::LLVM::ModuleTranslation &moduleTranslation,
2857 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2858 if (dependIterated.empty()) {
2859 buildDependDataLocator(dependKinds, dependVars, moduleTranslation,
2860 taskDeps.Deps);
2861 return mlir::success();
2862 }
2863
2864 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2865 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2866 unsigned numLocator = dependVars.size();
2867
2868 // Compute total count: locator deps + sum of iterator trip counts.
2869 llvm::Value *totalCount =
2870 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2871
2873 for (auto iter : dependIterated) {
2874 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2875 assert(itersOp && "depend_iterated value must be defined by omp.iterator");
2876 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2877 totalCount =
2878 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2879 }
2880
2881 // Heap-allocate the kmp_depend_info array so we don't risk
2882 // dynamic-sized alloca outside the entry block (e.g. inside loops).
2883 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2884 llvm::Value *depArray =
2885 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2886 totalCount, /*MallocF=*/nullptr, ".dep.arr.addr");
2887
2888 // Fill non-iterated entries at indices [0, numLocator).
2889 if (numLocator > 0) {
2891 buildDependDataLocator(dependKinds, dependVars, moduleTranslation, dds);
2892 for (auto [i, dd] : llvm::enumerate(dds)) {
2893 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2894 llvm::Value *entry =
2895 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2896 ompBuilder.emitTaskDependency(builder, entry, dd);
2897 }
2898 }
2899
2900 // Fill iterated entries starting at index numLocator.
 // NOTE(review): dependIteratedKinds is dereferenced unconditionally; this
 // assumes it is present with one kind per dependIterated operand -- confirm
 // the op verifier guarantees that.
2901 llvm::Value *offset =
2902 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2903 for (auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2904 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2905 dependIteratedKinds->getValue()[i]);
2906 llvm::omp::RTLDependenceKindTy rtlKind =
2907 convertDependKind(kindAttr.getValue());
2908
 // Iteration `linearIV` of iterator i writes entry offset + linearIV, with
 // the address yielded by the iterator region.
2909 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2910 if (failed(fillIteratorLoop(
2911 itersOp, builder, moduleTranslation, iterInfo, "dep_iterator",
2912 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2913 llvm::Value *addr =
2914 moduleTranslation.lookupValue(yield.getResults()[0]);
2915 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2916 llvm::Value *entry =
2917 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2918 ompBuilder.emitTaskDependency(
2919 builder, entry,
2920 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2921 addr});
2922 })))
2923 return mlir::failure();
2924
2925 // Advance offset by the trip count of this iterator.
2926 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2927 }
2928
2929 taskDeps.DepArray = depArray;
2930 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2931 return mlir::success();
2932}
2933
2934/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
///
/// Outline of what is emitted:
///  * Private variables that read from their mold get a member in a
///    heap-allocated task context structure. They are initialised and copied
///    in the "omp.private.init"/"omp.private.copy" blocks before the task is
///    created, while the original variables are still live.
///  * Affinity and dependence information is built before the createTask()
///    call.
///  * Inside the task body (bodyCB):
///    - the remaining private variables are allocated and initialised;
///    - the task region is translated;
///    - private variables are cleaned up;
///    - the task context structure is freed.
///  * A heap-allocated dependence array, if one was built, is freed after the
///    task has been created.
2935static LogicalResult
2936convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2937 LLVM::ModuleTranslation &moduleTranslation) {
2938 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2939 if (failed(checkImplementationStatus(*taskOp)))
2940 return failure();
2941
2942 PrivateVarsInfo privateVarsInfo(taskOp);
2943 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2944 privateVarsInfo.privatizers};
2945
2946 // Allocate and copy private variables before creating the task. This avoids
2947 // accessing invalid memory if (after this scope ends) the private variables
2948 // are initialized from host variables or if the variables are copied into
2949 // from host variables (firstprivate). The insertion point is just before
2950 // where the code for creating and scheduling the task will go. That puts this
2951 // code outside of the outlined task region, which is what we want because
2952 // this way the initialization and copy regions are executed immediately while
2953 // the host variable data are still live.
2955 InsertPointTy allocaIP =
2956 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
2957
2958 // Not using splitBB() because that requires the current block to have a
2959 // terminator.
2960 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2961 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2962 builder.getContext(), "omp.task.start",
2963 /*Parent=*/builder.GetInsertBlock()->getParent());
2964 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2965 builder.SetInsertPoint(branchToTaskStartBlock);
2966
2967 // Now do this again to make the initialization and copy blocks
2968 llvm::BasicBlock *copyBlock =
2969 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2970 llvm::BasicBlock *initBlock =
2971 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2972
2973 // Now the control flow graph should look like
2974 // starter_block:
2975 // <---- where we started when convertOmpTaskOp was called
2976 // br %omp.private.init
2977 // omp.private.init:
2978 // br %omp.private.copy
2979 // omp.private.copy:
2980 // br %omp.task.start
2981 // omp.task.start:
2982 // <---- where we want the insertion point to be when we call createTask()
2983
2984 // Save the alloca insertion point on ModuleTranslation stack for use in
2985 // nested regions.
2987 moduleTranslation, allocaIP, deallocBlocks);
2988
2989 // Allocate and initialize private variables
2990 builder.SetInsertPoint(initBlock->getTerminator());
2991
2992 // Create task variable structure
2993 taskStructMgr.generateTaskContextStruct();
2994 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2995 // of the body otherwise it will be the GEP not the struct which is forwarded
2996 // to the outlined function. GEPs forwarded in this way are passed in a
2997 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2998 // which may not be executed until after the current stack frame goes out of
2999 // scope.
3000 taskStructMgr.createGEPsToPrivateVars();
3001
3002 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
3003 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
3004 privateVarsInfo.blockArgs,
3005 taskStructMgr.getLLVMPrivateVarGEPs())) {
3006 // To be handled inside the task.
3007 if (!privDecl.readsFromMold())
3008 continue;
3009 assert(llvmPrivateVarAlloc &&
3010 "reads from mold so shouldn't have been skipped");
3011
3012 llvm::Expected<llvm::Value *> privateVarOrErr =
3013 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3014 blockArg, llvmPrivateVarAlloc, initBlock);
3015 if (!privateVarOrErr)
3016 return handleError(privateVarOrErr, *taskOp.getOperation());
3017
3019
3020 // TODO: this is a bit of a hack for Fortran character boxes.
3021 // Character boxes are passed by value into the init region and then the
3022 // initialized character box is yielded by value. Here we need to store the
3023 // yielded value into the private allocation, and load the private
3024 // allocation to match the type expected by region block arguments.
3025 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3026 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3027 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3028 // Load it so we have the value pointed to by the GEP
3029 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3030 llvmPrivateVarAlloc);
3031 }
3032 assert(llvmPrivateVarAlloc->getType() ==
3033 moduleTranslation.convertType(blockArg.getType()));
3034
3035 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
3036 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
3037 // stack allocated structure.
3038 }
3039
3040 // firstprivate copy region
3041 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3042 if (failed(copyFirstPrivateVars(
3043 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3044 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3045 taskOp.getPrivateNeedsBarrier())))
3046 return llvm::failure();
3047
 // Affinity lists are emitted here, still before omp.task.start.
3048 llvm::OpenMPIRBuilder::AffinityData ad;
3049 if (failed(buildAffinityData(taskOp, builder, moduleTranslation, ad)))
3050 return llvm::failure();
3051
3052 // Set up for call to createTask()
3053 builder.SetInsertPoint(taskStartBlock);
3054
3055 auto bodyCB =
3056 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3057 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3058 // Save the alloca insertion point on ModuleTranslation stack for use in
3059 // nested regions.
3061 moduleTranslation, allocaIP, deallocBlocks);
3062
3063 // translate the body of the task:
3064 builder.restoreIP(codegenIP);
3065
 // Private variables that do not read from their mold are allocated at the
 // task's alloca point and initialised inside the task body.
3066 llvm::BasicBlock *privInitBlock = nullptr;
3067 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3068 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3069 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3070 privateVarsInfo.mlirVars))) {
3071 auto [blockArg, privDecl, mlirPrivVar] = zip;
3072 // This is handled before the task executes
3073 if (privDecl.readsFromMold())
3074 continue;
3075
3076 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3077 llvm::Type *llvmAllocType =
3078 moduleTranslation.convertType(privDecl.getType());
3079 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3080 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3081 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3082
3083 llvm::Expected<llvm::Value *> privateVarOrError =
3084 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3085 blockArg, llvmPrivateVar, privInitBlock);
3086 if (!privateVarOrError)
3087 return privateVarOrError.takeError();
3088 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3089 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3090 }
3091
 // Recreate the member GEPs here, inside the task body, instead of reusing
 // the ones emitted in omp.private.init (see the comment above about not
 // forwarding those GEPs into the outlined function).
3092 taskStructMgr.createGEPsToPrivateVars();
3093 for (auto [i, llvmPrivVar] :
3094 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3095 if (!llvmPrivVar) {
3096 assert(privateVarsInfo.llvmVars[i] &&
3097 "This is added in the loop above");
3098 continue;
3099 }
3100 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3101 }
3102
3103 // Find and map the addresses of each variable within the task context
3104 // structure
3105 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3106 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3107 privateVarsInfo.privatizers)) {
3108 // This was handled above.
3109 if (!privateDecl.readsFromMold())
3110 continue;
3111 // Fix broken pass-by-value case for Fortran character boxes
3112 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3113 llvmPrivateVar = builder.CreateLoad(
3114 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3115 }
3116 assert(llvmPrivateVar->getType() ==
3117 moduleTranslation.convertType(blockArg.getType()));
3118 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3119 }
3120
3121 auto continuationBlockOrError = convertOmpOpRegions(
3122 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
3123 if (failed(handleError(continuationBlockOrError, *taskOp)))
3124 return llvm::make_error<PreviouslyReportedError>();
3125
3126 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3127
3128 if (failed(cleanupPrivateVars(taskOp, builder, moduleTranslation,
3129 taskOp.getLoc(), privateVarsInfo)))
3130 return llvm::make_error<PreviouslyReportedError>();
3131
3132 // Free heap allocated task context structure at the end of the task.
3133 taskStructMgr.freeStructPtr();
3134
3135 return llvm::Error::success();
3136 };
3137
3138 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3139 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3140 // The directive to match here is OMPD_taskgroup because it is the taskgroup
3141 // which is canceled. This is handled here because it is the task's cleanup
3142 // block which should be branched to.
3143 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
3144 llvm::omp::Directive::OMPD_taskgroup);
3145
 // Dependence data is emitted in omp.task.start, right before createTask().
 // NOTE(review): the failure return below (and the one after createTask)
 // happens after pushCancelFinalizationCB without a matching
 // popCancelFinalizationCB; confirm that aborting translation makes the
 // leftover finalization callback harmless.
3146 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
3147 if (failed(buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
3148 taskOp.getDependIterated(),
3149 taskOp.getDependIteratedKinds(), builder,
3150 moduleTranslation, dependencies)))
3151 return failure();
3152
3153 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3154 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3155 moduleTranslation.getOpenMPBuilder()->createTask(
3156 ompLoc, allocaIP, deallocBlocks, bodyCB, !taskOp.getUntied(),
3157 moduleTranslation.lookupValue(taskOp.getFinal()),
3158 moduleTranslation.lookupValue(taskOp.getIfExpr()), dependencies, ad,
3159 taskOp.getMergeable(),
3160 moduleTranslation.lookupValue(taskOp.getEventHandle()),
3161 moduleTranslation.lookupValue(taskOp.getPriority()));
3162
3163 if (failed(handleError(afterIP, *taskOp)))
3164 return failure();
3165
3166 // Set the correct branch target for task cancellation
3167 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3168
3169 builder.restoreIP(*afterIP);
3170
 // buildDependData malloc's the dependence array when iterated dependences
 // are present; release it once the task has been created.
3171 if (dependencies.DepArray)
3172 builder.CreateFree(dependencies.DepArray);
3173
3174 return success();
3175}
3176
3177/// The correct entry point is convertOmpTaskloopContextOp. This gets called
3178/// whilst lowering the body of the taskloop context (i.e. the task function).
3179static LogicalResult
3180convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp,
3181 llvm::IRBuilderBase &builder,
3182 LLVM::ModuleTranslation &moduleTranslation) {
3183 mlir::Operation &opInst = *loopWrapperOp.getOperation();
3184 if (failed(checkImplementationStatus(opInst)))
3185 return failure();
3186
3187 // Recurse into the loop body.
3188 auto continuationBlockOrError = convertOmpOpRegions(
3189 loopWrapperOp.getRegion(), "omp.taskloop.wrapper.region", builder,
3190 moduleTranslation);
3191
3192 if (failed(handleError(continuationBlockOrError, opInst)))
3193 return failure();
3194
3195 builder.SetInsertPoint(continuationBlockOrError.get());
3196 return success();
3197}
3198
3199/// Look up the given value in the mapping, and if it's not there, translate its
3200/// defining operation at the current builder insertion point. Only pure,
3201/// regionless operations are supported because the same operation will later be
3202/// translated again when the taskloop body itself is lowered.
///
/// Unmapped operands of the defining operation are translated recursively the
/// same way. Before returning, the operands mapped here and the results of the
/// defining operation are removed from the mapping again, so only the returned
/// llvm::Value (and the IR emitted at the insertion point) remains.
///
/// Returns a StringError when the value is an unmapped block argument, when
/// its defining op has regions or is not pure, or when converting that op
/// fails.
3203static llvm::Expected<llvm::Value *>
3205 LLVM::ModuleTranslation &moduleTranslation,
3206 llvm::IRBuilderBase &builder) {
3207 if (llvm::Value *mapped = moduleTranslation.lookupValue(value))
3208 return mapped;
3209
3210 Operation *defOp = value.getDefiningOp();
3211 if (!defOp)
3212 return llvm::make_error<llvm::StringError>(
3213 "value is a block argument and is not mapped",
3214 llvm::inconvertibleErrorCode());
3215 if (defOp->getNumRegions() != 0 || !isPure(defOp))
3216 return llvm::make_error<llvm::StringError>(
3217 "unsupported op defining taskloop loop bound",
3218 llvm::inconvertibleErrorCode());
3219
 // Mappings created only to translate defOp; they are undone at the end.
 // NOTE(review): on the error returns below, operands already mapped by this
 // loop are not forgotten -- confirm that translation failure makes this moot.
3220 SmallVector<Value> mappingsToRemove;
3221 mappingsToRemove.reserve(defOp->getNumOperands() + defOp->getNumResults());
3222 for (Value operand : defOp->getOperands()) {
3223 if (moduleTranslation.lookupValue(operand))
3224 continue;
3225
3226 llvm::Expected<llvm::Value *> operandOrError =
3227 lookupOrTranslatePureValue(operand, moduleTranslation, builder);
3228 if (!operandOrError)
3229 return operandOrError.takeError();
3230 moduleTranslation.mapValue(operand, *operandOrError);
3231 mappingsToRemove.push_back(operand);
3232 }
3233
3234 if (failed(moduleTranslation.convertOperation(*defOp, builder)))
3235 return llvm::make_error<llvm::StringError>(
3236 "failed to convert op defining taskloop loop bound",
3237 llvm::inconvertibleErrorCode());
3238
3239 llvm::Value *result = moduleTranslation.lookupValue(value);
3240 assert(result && "expected conversion of loop bound op to produce a value");
3241
 // Drop the result mappings too, so that the later, regular translation of
 // defOp does not find them already mapped.
3242 for (Value resultValue : defOp->getResults()) {
3243 if (moduleTranslation.lookupValue(resultValue))
3244 mappingsToRemove.push_back(resultValue);
3245 }
3246 for (Value mappedValue : mappingsToRemove)
3247 moduleTranslation.forgetMapping(mappedValue);
3248
3249 return result;
3250}
3251
3252static llvm::Error
3253computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
3254 LLVM::ModuleTranslation &moduleTranslation,
3255 llvm::Value *&lbVal, llvm::Value *&ubVal,
3256 llvm::Value *&stepVal) {
3257 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
3258 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
3259 Operation::operand_range steps = loopOp.getLoopSteps();
3260
3261 llvm::Expected<llvm::Value *> firstLbOrErr =
3262 lookupOrTranslatePureValue(lowerBounds[0], moduleTranslation, builder);
3263 if (!firstLbOrErr)
3264 return firstLbOrErr.takeError();
3265
3266 llvm::Type *boundType = (*firstLbOrErr)->getType();
3267 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3268 if (loopOp.getCollapseNumLoops() > 1) {
3269 // In cases where Collapse is used with Taskloop, the upper bound of the
3270 // iteration space needs to be recalculated to cater for the collapsed loop.
3271 // The Collapsed Loop UpperBound is the product of all collapsed
3272 // loop's tripcount.
3273 // The LowerBound for collapsed loops is always 1. When the loops are
3274 // collapsed, it will reset the bounds and introduce processing to ensure
3275 // the index's are presented as expected. As this happens after creating
3276 // Taskloop, these bounds need predicting. Example:
3277 // !$omp taskloop collapse(2)
3278 // do i = 1, 10
3279 // do j = 1, 5
3280 // ..
3281 // end do
3282 // end do
3283 // This loop above has a total of 50 iterations, so the lb will be 1, and
3284 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
3285 // that i and j are properly presented when used in the loop.
3286 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3288 i == 0 ? std::move(firstLbOrErr)
3289 : lookupOrTranslatePureValue(lowerBounds[i], moduleTranslation,
3290 builder);
3291 if (!lbOrErr)
3292 return lbOrErr.takeError();
3294 upperBounds[i], moduleTranslation, builder);
3295 if (!ubOrErr)
3296 return ubOrErr.takeError();
3298 lookupOrTranslatePureValue(steps[i], moduleTranslation, builder);
3299 if (!stepOrErr)
3300 return stepOrErr.takeError();
3301
3302 llvm::Value *loopLb = *lbOrErr;
3303 llvm::Value *loopUb = *ubOrErr;
3304 llvm::Value *loopStep = *stepOrErr;
3305 // In some cases, such as where the ub is less than the lb so the loop
3306 // steps down, the calculation for the loopTripCount is swapped. To ensure
3307 // the correct value is found, calculate both UB - LB and LB - UB then
3308 // select which value to use depending on how the loop has been
3309 // configured.
3310 llvm::Value *loopLbMinusOne = builder.CreateSub(
3311 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3312 llvm::Value *loopUbMinusOne = builder.CreateSub(
3313 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3314 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3315 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3316 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3317 llvm::Value *loopTripCount =
3318 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3319 loopTripCount = builder.CreateBinaryIntrinsic(
3320 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3321 // For loops that have a step value not equal to 1, we need to adjust the
3322 // trip count to ensure the correct number of iterations for the loop is
3323 // captured.
3324 llvm::Value *loopTripCountDivStep =
3325 builder.CreateSDiv(loopTripCount, loopStep);
3326 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3327 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3328 llvm::Value *loopTripCountRem =
3329 builder.CreateSRem(loopTripCount, loopStep);
3330 loopTripCountRem = builder.CreateBinaryIntrinsic(
3331 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3332 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3333 loopTripCountRem,
3334 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3335 0));
3336 loopTripCount =
3337 builder.CreateAdd(loopTripCountDivStep,
3338 builder.CreateZExtOrTrunc(
3339 needsRoundUp, loopTripCountDivStep->getType()));
3340 ubVal = builder.CreateMul(ubVal, loopTripCount);
3341 }
3342 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3343 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3344 } else {
3346 lookupOrTranslatePureValue(upperBounds[0], moduleTranslation, builder);
3347 if (!ubOrErr)
3348 return ubOrErr.takeError();
3350 lookupOrTranslatePureValue(steps[0], moduleTranslation, builder);
3351 if (!stepOrErr)
3352 return stepOrErr.takeError();
3353 lbVal = *firstLbOrErr;
3354 ubVal = *ubOrErr;
3355 stepVal = *stepOrErr;
3356 }
3357
3358 assert(lbVal != nullptr && "Expected value for lbVal");
3359 assert(ubVal != nullptr && "Expected value for ubVal");
3360 assert(stepVal != nullptr && "Expected value for stepVal");
3361 return llvm::Error::success();
3362}
3363
3364// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
// The region of the omp.taskloop.context op becomes the body of the generated
// task. The omp.loop_nest wrapped by its taskloop wrapper op supplies the loop
// bounds passed to OpenMPIRBuilder::createTaskloop.
//
// Outline of the emitted code:
//  - Private variables whose privatizer reads from the mold live in a task
//    context structure. They are initialized in the "omp.private.init" block
//    and firstprivate copies are made in the "omp.private.copy" block. Both
//    blocks are placed ahead of "omp.taskloop.wrapper.start".
//  - The loop bounds are computed in "omp.taskloop.wrapper.start" (see
//    computeTaskloopBounds).
//  - bodyCB allocates the remaining privates inside the task, maps the region
//    block arguments, lowers the context region and cleans up afterwards.
//  - taskDupCB duplicates the task context structure whenever the original
//    task is duplicated. It is only passed when such a structure exists.
// Returns failure (after reporting the error) if any step fails.
3365static LogicalResult
3366convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
3367 llvm::IRBuilderBase &builder,
3368 LLVM::ModuleTranslation &moduleTranslation) {
3369 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3370 mlir::Operation &opInst = *contextOp.getOperation();
3371 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3372 if (failed(checkImplementationStatus(opInst)))
3373 return failure();
3374
3375 // It stores the pointer of allocated firstprivate copies,
3376 // which can be used later for freeing the allocated space.
// NOTE(review): within this function the entries are only assigned, never
// read; confirm whether this vector is still needed.
3377 SmallVector<llvm::Value *> llvmFirstPrivateVars;
3378 PrivateVarsInfo privateVarsInfo(contextOp);
3379 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3380 privateVarsInfo.privatizers};
3381
3383 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3384 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3385
// The current block must end exactly at the insertion point: a branch to
// the new taskloop start block is appended there. The init and copy blocks
// for private variables are then split off in front of that branch.
3386 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3387 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3388 builder.getContext(), "omp.taskloop.wrapper.start",
3389 /*Parent=*/builder.GetInsertBlock()->getParent());
3390 llvm::Instruction *branchToTaskloopStartBlock =
3391 builder.CreateBr(taskloopStartBlock);
3392 builder.SetInsertPoint(branchToTaskloopStartBlock);
3393
3394 llvm::BasicBlock *copyBlock =
3395 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
3396 llvm::BasicBlock *initBlock =
3397 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
3398
3400 moduleTranslation, allocaIP, deallocBlocks);
3401
3402 // Allocate and initialize private variables
3403 builder.SetInsertPoint(initBlock->getTerminator());
3404
3405 // TODO: don't allocate if the loop has zero iterations.
3406 taskStructMgr.generateTaskContextStruct();
3407 taskStructMgr.createGEPsToPrivateVars();
3408
3409 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
3410
3411 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3412 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
3413 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3414 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3415 // To be handled inside the taskloop.
3416 if (!privDecl.readsFromMold())
3417 continue;
3418 assert(llvmPrivateVarAlloc &&
3419 "reads from mold so shouldn't have been skipped");
3420
3421 llvm::Expected<llvm::Value *> privateVarOrErr =
3422 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3423 blockArg, llvmPrivateVarAlloc, initBlock);
3424 if (!privateVarOrErr)
3425 return handleError(privateVarOrErr, *contextOp.getOperation());
3426
3427 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3428
3429 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3430 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3431
// For non-pointer block arguments (e.g. by-value character boxes) the
// initialized value is stored into the struct slot and reloaded, so the
// value has the block argument's type.
3432 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3433 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3434 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3435 // Load it so we have the value pointed to by the GEP
3436 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3437 llvmPrivateVarAlloc);
3438 }
3439 assert(llvmPrivateVarAlloc->getType() ==
3440 moduleTranslation.convertType(blockArg.getType()));
3441 }
3442
3443 // firstprivate copy region
3444 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3445 if (failed(copyFirstPrivateVars(
3446 contextOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3447 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3448 contextOp.getPrivateNeedsBarrier())))
3449 return llvm::failure();
3450
3451 // Set up insertion point for call to createTaskloop()
3452 builder.SetInsertPoint(taskloopStartBlock);
3453
3454 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3455 llvm::Value *lbVal = nullptr;
3456 llvm::Value *ubVal = nullptr;
3457 llvm::Value *stepVal = nullptr;
3458 if (llvm::Error err = computeTaskloopBounds(
3459 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3460 return handleError(std::move(err), opInst);
3461
// Body of the generated task. Privates that do not read from the mold are
// allocated here; those that do are taken from the task context structure.
3462 auto bodyCB =
3463 [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3464 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3465 // Save the alloca insertion point on ModuleTranslation stack for use in
3466 // nested regions.
3468 moduleTranslation, allocaIP, deallocBlocks);
3469
3470 // translate the body of the taskloop:
3471 builder.restoreIP(codegenIP);
3472
3473 llvm::BasicBlock *privInitBlock = nullptr;
3474 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3475 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3476 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3477 privateVarsInfo.mlirVars))) {
3478 auto [blockArg, privDecl, mlirPrivVar] = zip;
3479 // This is handled before the task executes
3480 if (privDecl.readsFromMold())
3481 continue;
3482
3483 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3484 llvm::Type *llvmAllocType =
3485 moduleTranslation.convertType(privDecl.getType());
3486 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3487 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3488 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3489
3490 llvm::Expected<llvm::Value *> privateVarOrError =
3491 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3492 blockArg, llvmPrivateVar, privInitBlock);
3493 if (!privateVarOrError)
3494 return privateVarOrError.takeError();
3495 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3496 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3497 }
3498
3499 taskStructMgr.createGEPsToPrivateVars();
3500 for (auto [i, llvmPrivVar] :
3501 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3502 if (!llvmPrivVar) {
3503 assert(privateVarsInfo.llvmVars[i] &&
3504 "This is added in the loop above");
3505 continue;
3506 }
3507 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3508 }
3509
3510 // Find and map the addresses of each variable within the taskloop context
3511 // structure
3512 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3513 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3514 privateVarsInfo.privatizers)) {
3515 // This was handled above.
3516 if (!privateDecl.readsFromMold())
3517 continue;
3518 // Fix broken pass-by-value case for Fortran character boxes
3519 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3520 llvmPrivateVar = builder.CreateLoad(
3521 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3522 }
3523 assert(llvmPrivateVar->getType() ==
3524 moduleTranslation.convertType(blockArg.getType()));
3525 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3526 }
3527
3528 // Lower the contents of the taskloop context region: this is the body of
3529 // the generated task, not the loop.
3530 auto continuationBlockOrError = convertOmpOpRegions(
3531 contextOp.getRegion(), "omp.taskloop.context.region", builder,
3532 moduleTranslation);
3533
3534 if (failed(handleError(continuationBlockOrError, opInst)))
3535 return llvm::make_error<PreviouslyReportedError>();
3536
3537 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3538
3539 // This is freeing the private variables as mapped inside of the task: these
3540 // will be per-task private copies possibly after task duplication. This is
3541 // handled transparently by how these are passed to the structure passed
3542 // into the outlined function. When the task is duplicated, that structure
3543 // is duplicated too.
3544 if (failed(cleanupPrivateVars(contextOp, builder, moduleTranslation,
3545 contextOp.getLoc(), privateVarsInfo)))
3546 return llvm::make_error<PreviouslyReportedError>();
3547 // Similarly, the task context structure freed inside the task is the
3548 // per-task copy after task duplication.
3549 taskStructMgr.freeStructPtr();
3550
3551 return llvm::Error::success();
3552 };
3553
3554 // Taskloop divides into an appropriate number of tasks by repeatedly
3555 // duplicating the original task. Each time this is done, the task context
3556 // structure must be duplicated too.
// The callback loads the source structure pointer from srcPtr, allocates a
// fresh structure, stores its pointer to destPtr, re-runs the init regions
// for mold-reading privates and copies firstprivate values from the source.
3557 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3558 llvm::Value *destPtr, llvm::Value *srcPtr)
3560 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3561 builder.restoreIP(codegenIP);
3562
3563 llvm::Type *ptrTy =
3564 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3565 llvm::Value *src =
3566 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
3567
3568 TaskContextStructManager &srcStructMgr = taskStructMgr;
3569 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3570 privateVarsInfo.privatizers);
3571 destStructMgr.generateTaskContextStruct();
3572 llvm::Value *dest = destStructMgr.getStructPtr();
3573 dest->setName("omp.taskloop.context.dest");
3574 builder.CreateStore(dest, destPtr);
3575
3577 srcStructMgr.createGEPsToPrivateVars(src);
3579 destStructMgr.createGEPsToPrivateVars(dest);
3580
3581 // Inline init regions.
3582 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3583 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
3584 privateVarsInfo.blockArgs, destGEPs)) {
3585 // To be handled inside task body.
3586 if (!privDecl.readsFromMold())
3587 continue;
3588 assert(llvmPrivateVarAlloc &&
3589 "reads from mold so shouldn't have been skipped");
3590
3591 llvm::Expected<llvm::Value *> privateVarOrErr =
3592 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3593 llvmPrivateVarAlloc, builder.GetInsertBlock());
3594 if (!privateVarOrErr)
3595 return privateVarOrErr.takeError();
3596
3598
3599 // TODO: this is a bit of a hack for Fortran character boxes.
3600 // Character boxes are passed by value into the init region and then the
3601 // initialized character box is yielded by value. Here we need to store
3602 // the yielded value into the private allocation, and load the private
3603 // allocation to match the type expected by region block arguments.
3604 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3605 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3606 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3607 // Load it so we have the value pointed to by the GEP
3608 llvmPrivateVarAlloc = builder.CreateLoad(
3609 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3610 }
3611 assert(llvmPrivateVarAlloc->getType() ==
3612 moduleTranslation.convertType(blockArg.getType()));
3613
3614 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
3615 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
3616 // through a stack allocated structure.
3617 }
3618
3619 if (failed(copyFirstPrivateVars(contextOp.getOperation(), builder,
3620 moduleTranslation, srcGEPs, destGEPs,
3621 privateVarsInfo.privatizers,
3622 contextOp.getPrivateNeedsBarrier())))
3623 return llvm::make_error<PreviouslyReportedError>();
3624
3625 return builder.saveIP();
3626 };
3627
// Returns the CanonicalLoopInfo of the current loop. This is presumably
// invoked by createTaskloop after bodyCB has lowered the loop nest (TODO
// confirm against OpenMPIRBuilder::createTaskloop).
3628 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
3629 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3630 return loopInfo;
3631 };
3632
// `grainsize` carries either the grainsize or the num_tasks value. `sched`
// tells createTaskloop which one it is: 0 = neither, 1 = grainsize,
// 2 = num_tasks. grainsize takes precedence when both are present.
3633 llvm::Value *ifCond = nullptr;
3634 llvm::Value *grainsize = nullptr;
3635 int sched = 0; // default
3636 mlir::Value grainsizeVal = contextOp.getGrainsize();
3637 mlir::Value numTasksVal = contextOp.getNumTasks();
3638 if (Value ifVar = contextOp.getIfExpr())
3639 ifCond = moduleTranslation.lookupValue(ifVar);
3640 if (grainsizeVal) {
3641 grainsize = moduleTranslation.lookupValue(grainsizeVal);
3642 sched = 1; // grainsize
3643 } else if (numTasksVal) {
3644 grainsize = moduleTranslation.lookupValue(numTasksVal);
3645 sched = 2; // num_tasks
3646 }
3647
// A task duplication callback is only needed when there is a task context
// structure to duplicate.
3648 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
3649 if (taskStructMgr.getStructPtr())
3650 taskDupOrNull = taskDupCB;
3651
3652 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3653 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3654 // The directive to match here is OMPD_taskgroup because it is the
3655 // taskgroup which is canceled. This is handled here because it is the
3656 // task's cleanup block which should be branched to. It doesn't depend upon
3657 // nogroup because even in that case the taskloop might still be inside an
3658 // explicit taskgroup.
3659 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, contextOp,
3660 llvm::omp::Directive::OMPD_taskgroup);
3661
3662 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3663 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3664 moduleTranslation.getOpenMPBuilder()->createTaskloop(
3665 ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
3666 stepVal, contextOp.getUntied(), ifCond, grainsize,
3667 contextOp.getNogroup(), sched,
3668 moduleTranslation.lookupValue(contextOp.getFinal()),
3669 contextOp.getMergeable(),
3670 moduleTranslation.lookupValue(contextOp.getPriority()),
3671 loopOp.getCollapseNumLoops(), taskDupOrNull,
3672 taskStructMgr.getStructPtr());
3673
3674 if (failed(handleError(afterIP, opInst)))
3675 return failure();
3676
3677 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3678
3679 builder.restoreIP(*afterIP);
3680 return success();
3681}
3682
3683/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
3684static LogicalResult
3685convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
3686 LLVM::ModuleTranslation &moduleTranslation) {
3687 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3688 if (failed(checkImplementationStatus(*tgOp)))
3689 return failure();
3690
3691 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3692 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
3693 builder.restoreIP(codegenIP);
3694 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
3695 builder, moduleTranslation)
3696 .takeError();
3697 };
3698
3700 InsertPointTy allocaIP =
3701 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
3702 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3703 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3704 moduleTranslation.getOpenMPBuilder()->createTaskgroup(
3705 ompLoc, allocaIP, deallocBlocks, bodyCB);
3706
3707 if (failed(handleError(afterIP, *tgOp)))
3708 return failure();
3709
3710 builder.restoreIP(*afterIP);
3711 return success();
3712}
3713
3714static LogicalResult
3715convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
3716 LLVM::ModuleTranslation &moduleTranslation) {
3717 if (failed(checkImplementationStatus(*twOp)))
3718 return failure();
3719
3720 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
3721 return success();
3722}
3723
3724/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
///
/// Processing order:
///  1. Private, firstprivate and reduction variables are allocated and
///     initialized.
///  2. The wsloop region is lowered, which produces the CanonicalLoopInfo of
///     the wrapped omp.loop_nest.
///  3. OpenMPIRBuilder::applyWorkshareLoop turns that canonical loop into a
///     worksharing loop.
///  4. Linear variables, task cancellation, reductions and private-variable
///     cleanup are handled around that call.
/// When the direct parent is omp.distribute, the loop is lowered as part of
/// `distribute parallel do` and uses the parent's dist_schedule settings.
/// Returns failure (after reporting the error) if any step fails.
3725static LogicalResult
3726convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
3727 LLVM::ModuleTranslation &moduleTranslation) {
3728 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3729 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3730 if (failed(checkImplementationStatus(opInst)))
3731 return failure();
3732
3733 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3734 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
3735 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3736
3737 // Static is the default.
3738 auto schedule =
3739 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3740
3741 // Find the loop configuration.
// NOTE(review): the IV type is taken from the first loop's step. The chunk
// sizes below are sign-extended or truncated to that type.
3742 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
3743 llvm::Type *ivType = step->getType();
3744 llvm::Value *chunk = nullptr;
3745 if (wsloopOp.getScheduleChunk()) {
3746 llvm::Value *chunkVar =
3747 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
3748 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3749 }
3750
// When the direct parent is omp.distribute, pick up its dist_schedule
// clause (static flag and optional chunk size).
3751 omp::DistributeOp distributeOp = nullptr;
3752 llvm::Value *distScheduleChunk = nullptr;
3753 bool hasDistSchedule = false;
3754 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
3755 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
3756 hasDistSchedule = distributeOp.getDistScheduleStatic();
3757 if (distributeOp.getDistScheduleChunkSize()) {
3758 llvm::Value *chunkVar = moduleTranslation.lookupValue(
3759 distributeOp.getDistScheduleChunkSize());
3760 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3761 }
3762 }
3763
3764 PrivateVarsInfo privateVarsInfo(wsloopOp);
3765
3767 collectReductionDecls(wsloopOp, reductionDecls);
3768
3769 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3770 findAllocInsertPoints(builder, moduleTranslation);
3771
3772 SmallVector<llvm::Value *> privateReductionVariables(
3773 wsloopOp.getNumReductionVars());
3774
3776 wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
3777 if (handleError(afterAllocas, opInst).failed())
3778 return failure();
3779
3780 DenseMap<Value, llvm::Value *> reductionVariableMap;
3781
3782 MutableArrayRef<BlockArgument> reductionArgs =
3783 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3784
3785 SmallVector<DeferredStore> deferredStores;
3786
3787 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3788 moduleTranslation, allocaIP, reductionDecls,
3789 privateReductionVariables, reductionVariableMap,
3790 deferredStores, isByRef)))
3791 return failure();
3792
3793 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3794 opInst)
3795 .failed())
3796 return failure();
3797
3798 if (failed(copyFirstPrivateVars(
3799 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3800 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3801 wsloopOp.getPrivateNeedsBarrier())))
3802 return failure();
3803
3804 assert(afterAllocas.get()->getSinglePredecessor());
3805 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3806 moduleTranslation,
3807 afterAllocas.get()->getSinglePredecessor(),
3808 reductionDecls, privateReductionVariables,
3809 reductionVariableMap, isByRef, deferredStores)))
3810 return failure();
3811
3812 // TODO: Handle doacross loops when the ordered clause has a parameter.
3813 bool isOrdered = wsloopOp.getOrdered().has_value();
3814 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3815 bool isSimd = wsloopOp.getScheduleSimd();
3816 bool loopNeedsBarrier = !wsloopOp.getNowait();
3817
3818 // The only legal way for the direct parent to be omp.distribute is that this
3819 // represents 'distribute parallel do'. Otherwise, this is a regular
3820 // worksharing loop.
3821 llvm::omp::WorksharingLoopType workshareLoopType =
3822 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3823 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3824 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3825
3826 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3827 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3828 llvm::omp::Directive::OMPD_for);
3829
3830 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3831
3832 // Initialize linear variables and linear step
3833 LinearClauseProcessor linearClauseProcessor;
3834
3835 if (!wsloopOp.getLinearVars().empty()) {
3836 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3837 for (mlir::Attribute linearVarType : linearVarTypes)
3838 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3839
3840 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3841 linearClauseProcessor.createLinearVar(
3842 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3843 idx);
3844 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3845 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3846 }
3847
// Remember the block preceding the region; its single successor is used
// later when rewriting linear variables in place.
3848 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3850 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3851
3852 if (failed(handleError(regionBlock, opInst)))
3853 return failure();
3854
3855 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3856
3857 // Emit Initialization and Update IR for linear variables
3858 if (!wsloopOp.getLinearVars().empty()) {
3859 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3860 loopInfo->getPreheader());
3861 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3862 moduleTranslation.getOpenMPBuilder()->createBarrier(
3863 builder.saveIP(), llvm::omp::OMPD_barrier);
3864 if (failed(handleError(afterBarrierIP, *loopOp)))
3865 return failure();
3866 builder.restoreIP(*afterBarrierIP);
3867 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3868 loopInfo->getIndVar());
3869 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3870 }
3871
3872 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3873
3874 // Check if we can generate no-loop kernel
3875 bool noLoopMode = false;
3876 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3877 if (targetOp) {
3878 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3879 // We need this check because, without it, noLoopMode would be set to true
3880 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3881 // that loop is not the top-level SPMD one.
3882 if (loopOp == targetCapturedOp) {
3883 if (targetOp.getKernelExecFlags(targetCapturedOp) ==
3884 omp::TargetExecMode::no_loop)
3885 noLoopMode = true;
3886 }
3887 }
3888
3889 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3890 ompBuilder->applyWorkshareLoop(
3891 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3892 convertToScheduleKind(schedule), chunk, isSimd,
3893 scheduleMod == omp::ScheduleModifier::monotonic,
3894 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3895 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3896
3897 if (failed(handleError(wsloopIP, opInst)))
3898 return failure();
3899
3900 // Emit finalization and in-place rewrites for linear vars.
3901 if (!wsloopOp.getLinearVars().empty()) {
3902 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3903 assert(loopInfo->getLastIter() &&
3904 "`lastiter` in CanonicalLoopInfo is nullptr");
3905 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3906 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3907 loopInfo->getLastIter());
// NOTE(review): afterBarrierIP is only checked for errors here. The
// insertion point is restored to oldIP once the rewrites are done.
3908 if (failed(handleError(afterBarrierIP, *loopOp)))
3909 return failure();
3910 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3911 linearClauseProcessor.rewriteInPlace(
3912 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3913 "omp.loop_nest.region", index);
3914
3915 builder.restoreIP(oldIP);
3916 }
3917
3918 // Set the correct branch target for task cancellation
3919 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3920
3921 // Process the reductions if required.
3922 if (failed(createReductionsAndCleanup(
3923 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3924 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3925 /*isTeamsReduction=*/false)))
3926 return failure();
3927
3928 return cleanupPrivateVars(wsloopOp, builder, moduleTranslation,
3929 wsloopOp.getLoc(), privateVarsInfo);
3930}
3931
3932/// Converts the OpenMP parallel operation to LLVM IR.
3933static LogicalResult
3934convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3935 LLVM::ModuleTranslation &moduleTranslation) {
3936 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3937 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3938 assert(isByRef.size() == opInst.getNumReductionVars());
3939 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3940 bool isCancellable = constructIsCancellable(opInst);
3941
3942 if (failed(checkImplementationStatus(*opInst)))
3943 return failure();
3944
3945 PrivateVarsInfo privateVarsInfo(opInst);
3946
3947 // Collect reduction declarations
3949 collectReductionDecls(opInst, reductionDecls);
3950 SmallVector<llvm::Value *> privateReductionVariables(
3951 opInst.getNumReductionVars());
3952 SmallVector<DeferredStore> deferredStores;
3953
3954 auto bodyGenCB =
3955 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3956 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
3958 opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
3959 if (handleError(afterAllocas, *opInst).failed())
3960 return llvm::make_error<PreviouslyReportedError>();
3961
3962 // Allocate reduction vars
3963 DenseMap<Value, llvm::Value *> reductionVariableMap;
3964
3965 MutableArrayRef<BlockArgument> reductionArgs =
3966 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3967
3968 allocaIP =
3969 InsertPointTy(allocaIP.getBlock(),
3970 allocaIP.getBlock()->getTerminator()->getIterator());
3971
3972 if (failed(allocReductionVars(
3973 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3974 reductionDecls, privateReductionVariables, reductionVariableMap,
3975 deferredStores, isByRef)))
3976 return llvm::make_error<PreviouslyReportedError>();
3977
3978 assert(afterAllocas.get()->getSinglePredecessor());
3979 builder.restoreIP(codeGenIP);
3980
3981 if (handleError(
3982 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3983 *opInst)
3984 .failed())
3985 return llvm::make_error<PreviouslyReportedError>();
3986
3987 if (failed(copyFirstPrivateVars(
3988 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3989 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3990 opInst.getPrivateNeedsBarrier())))
3991 return llvm::make_error<PreviouslyReportedError>();
3992
3993 if (failed(
3994 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3995 afterAllocas.get()->getSinglePredecessor(),
3996 reductionDecls, privateReductionVariables,
3997 reductionVariableMap, isByRef, deferredStores)))
3998 return llvm::make_error<PreviouslyReportedError>();
3999
4000 // Save the alloca insertion point on ModuleTranslation stack for use in
4001 // nested regions.
4003 moduleTranslation, allocaIP, deallocBlocks);
4004
4005 // ParallelOp has only one region associated with it.
4007 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
4008 if (!regionBlock)
4009 return regionBlock.takeError();
4010
4011 // Process the reductions if required.
4012 if (opInst.getNumReductionVars() > 0) {
4013 // Collect reduction info
4015 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
4017 owningReductionGenRefDataPtrGens;
4019 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
4020 owningReductionGens, owningAtomicReductionGens,
4021 owningReductionGenRefDataPtrGens,
4022 privateReductionVariables, reductionInfos, isByRef);
4023
4024 // Move to region cont block
4025 builder.SetInsertPoint((*regionBlock)->getTerminator());
4026
4027 // Generate reductions from info
4028 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
4029 builder.SetInsertPoint(tempTerminator);
4030
4031 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
4032 ompBuilder->createReductions(
4033 builder.saveIP(), allocaIP, reductionInfos, isByRef,
4034 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
4035 if (!contInsertPoint)
4036 return contInsertPoint.takeError();
4037
4038 if (!contInsertPoint->getBlock())
4039 return llvm::make_error<PreviouslyReportedError>();
4040
4041 tempTerminator->eraseFromParent();
4042 builder.restoreIP(*contInsertPoint);
4043 }
4044
4045 return llvm::Error::success();
4046 };
4047
4048 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
4049 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
4050 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
4051 // bodyGenCB.
4052 replVal = &val;
4053 return codeGenIP;
4054 };
4055
4056 // TODO: Perform finalization actions for variables. This has to be
4057 // called for variables which have destructors/finalizers.
4058 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
4059 InsertPointTy oldIP = builder.saveIP();
4060 builder.restoreIP(codeGenIP);
4061
4062 // if the reduction has a cleanup region, inline it here to finalize the
4063 // reduction variables
4064 SmallVector<Region *> reductionCleanupRegions;
4065 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
4066 [](omp::DeclareReductionOp reductionDecl) {
4067 return &reductionDecl.getCleanupRegion();
4068 });
4069 if (failed(inlineOmpRegionCleanup(
4070 reductionCleanupRegions, privateReductionVariables,
4071 moduleTranslation, builder, "omp.reduction.cleanup")))
4072 return llvm::createStringError(
4073 "failed to inline `cleanup` region of `omp.declare_reduction`");
4074
4075 if (failed(cleanupPrivateVars(opInst, builder, moduleTranslation,
4076 opInst.getLoc(), privateVarsInfo)))
4077 return llvm::make_error<PreviouslyReportedError>();
4078
4079 // If we could be performing cancellation, add the cancellation barrier on
4080 // the way out of the outlined region.
4081 if (isCancellable) {
4082 auto IPOrErr = ompBuilder->createBarrier(
4083 llvm::OpenMPIRBuilder::LocationDescription(builder),
4084 llvm::omp::Directive::OMPD_unknown,
4085 /* ForceSimpleCall */ false,
4086 /* CheckCancelFlag */ false);
4087 if (!IPOrErr)
4088 return IPOrErr.takeError();
4089 }
4090
4091 builder.restoreIP(oldIP);
4092 return llvm::Error::success();
4093 };
4094
4095 llvm::Value *ifCond = nullptr;
4096 if (auto ifVar = opInst.getIfExpr())
4097 ifCond = moduleTranslation.lookupValue(ifVar);
4098 llvm::Value *numThreads = nullptr;
4099 if (!opInst.getNumThreadsVars().empty())
4100 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
4101 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
4102 if (auto bind = opInst.getProcBindKind())
4103 pbKind = getProcBindKind(*bind);
4104
4106 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4107 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
4108 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4109
4110 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4111 ompBuilder->createParallel(ompLoc, allocaIP, deallocBlocks, bodyGenCB,
4112 privCB, finiCB, ifCond, numThreads, pbKind,
4113 isCancellable);
4114
4115 if (failed(handleError(afterIP, *opInst)))
4116 return failure();
4117
4118 builder.restoreIP(*afterIP);
4119 return success();
4120}
4121
4122/// Convert Order attribute to llvm::omp::OrderKind.
4123static llvm::omp::OrderKind
4124convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
4125 if (!o)
4126 return llvm::omp::OrderKind::OMP_ORDER_unknown;
4127 switch (*o) {
4128 case omp::ClauseOrderKind::Concurrent:
4129 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
4130 }
4131 llvm_unreachable("Unknown ClauseOrderKind kind");
4132}
4133
4134/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
4135static LogicalResult
4136convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
4137 LLVM::ModuleTranslation &moduleTranslation) {
4138 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4139 auto simdOp = cast<omp::SimdOp>(opInst);
4140
4141 if (failed(checkImplementationStatus(opInst)))
4142 return failure();
4143
4144 PrivateVarsInfo privateVarsInfo(simdOp);
4145
4146 MutableArrayRef<BlockArgument> reductionArgs =
4147 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
4148 DenseMap<Value, llvm::Value *> reductionVariableMap;
4149 SmallVector<llvm::Value *> privateReductionVariables(
4150 simdOp.getNumReductionVars());
4151 SmallVector<DeferredStore> deferredStores;
4153 collectReductionDecls(simdOp, reductionDecls);
4154 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
4155 assert(isByRef.size() == simdOp.getNumReductionVars());
4156
4157 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4158 findAllocInsertPoints(builder, moduleTranslation);
4159
4161 simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
4162 if (handleError(afterAllocas, opInst).failed())
4163 return failure();
4164
4165 // Initialize linear variables and linear step
4166 LinearClauseProcessor linearClauseProcessor;
4167
4168 if (!simdOp.getLinearVars().empty()) {
4169 auto linearVarTypes = simdOp.getLinearVarTypes().value();
4170 for (mlir::Attribute linearVarType : linearVarTypes)
4171 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
4172 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
4173 bool isImplicit = false;
4174 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
4175 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
4176 // If the linear variable is implicit, reuse the already
4177 // existing llvm::Value
4178 if (linearVar == mlirPrivVar) {
4179 isImplicit = true;
4180 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
4181 llvmPrivateVar, idx);
4182 break;
4183 }
4184 }
4185
4186 if (!isImplicit)
4187 linearClauseProcessor.createLinearVar(
4188 builder, moduleTranslation,
4189 moduleTranslation.lookupValue(linearVar), idx);
4190 }
4191 for (mlir::Value linearStep : simdOp.getLinearStepVars())
4192 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
4193 }
4194
4195 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
4196 moduleTranslation, allocaIP, reductionDecls,
4197 privateReductionVariables, reductionVariableMap,
4198 deferredStores, isByRef)))
4199 return failure();
4200
4201 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
4202 opInst)
4203 .failed())
4204 return failure();
4205
4206 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
4207 // SIMD.
4208
4209 assert(afterAllocas.get()->getSinglePredecessor());
4210 if (failed(initReductionVars(simdOp, reductionArgs, builder,
4211 moduleTranslation,
4212 afterAllocas.get()->getSinglePredecessor(),
4213 reductionDecls, privateReductionVariables,
4214 reductionVariableMap, isByRef, deferredStores)))
4215 return failure();
4216
4217 llvm::ConstantInt *simdlen = nullptr;
4218 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
4219 simdlen = builder.getInt64(simdlenVar.value());
4220
4221 llvm::ConstantInt *safelen = nullptr;
4222 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
4223 safelen = builder.getInt64(safelenVar.value());
4224
4225 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
4226 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
4227
4228 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
4229 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
4230 mlir::OperandRange operands = simdOp.getAlignedVars();
4231 for (size_t i = 0; i < operands.size(); ++i) {
4232 llvm::Value *alignment = nullptr;
4233 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
4234 llvm::Type *ty = llvmVal->getType();
4235
4236 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4237 alignment = builder.getInt64(intAttr.getInt());
4238 assert(ty->isPointerTy() && "Invalid type for aligned variable");
4239 assert(alignment && "Invalid alignment value");
4240
4241 // Check if the alignment value is not a power of 2. If so, skip emitting
4242 // alignment.
4243 if (!intAttr.getValue().isPowerOf2())
4244 continue;
4245
4246 auto curInsert = builder.saveIP();
4247 builder.SetInsertPoint(sourceBlock);
4248 llvmVal = builder.CreateLoad(ty, llvmVal);
4249 builder.restoreIP(curInsert);
4250 alignedVars[llvmVal] = alignment;
4251 }
4252
4254 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
4255
4256 if (failed(handleError(regionBlock, opInst)))
4257 return failure();
4258
4259 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
4260 // Emit Initialization for linear variables
4261 if (simdOp.getLinearVars().size()) {
4262 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4263 loopInfo->getPreheader());
4264
4265 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4266 loopInfo->getIndVar());
4267 }
4268 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4269
4270 ompBuilder->applySimd(loopInfo, alignedVars,
4271 simdOp.getIfExpr()
4272 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
4273 : nullptr,
4274 order, simdlen, safelen);
4275
4276 linearClauseProcessor.emitStoresForLinearVar(builder);
4277
4278 // Check if this SIMD loop contains ordered regions
4279 bool hasOrderedRegions = false;
4280 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4281 hasOrderedRegions = true;
4282 return WalkResult::interrupt();
4283 });
4284
4285 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
4286 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
4287 llvm::BasicBlock *endBB = *regionBlock;
4288 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4289 "omp.loop_nest.region", index);
4290
4291 if (hasOrderedRegions) {
4292 // Also rewrite uses in ordered regions so they read the current value
4293 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4294 "omp.ordered.region", index);
4295 // Also rewrite uses in finalize blocks (code after ordered regions)
4296 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
4297 "omp_region.finalize", index);
4298 }
4299 }
4300
4301 // We now need to reduce the per-simd-lane reduction variable into the
4302 // original variable. This works a bit differently to other reductions (e.g.
4303 // wsloop) because we don't need to call into the OpenMP runtime to handle
4304 // threads: everything happened in this one thread.
4305 for (auto [i, tuple] : llvm::enumerate(
4306 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4307 privateReductionVariables))) {
4308 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4309
4310 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
4311 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
4312 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
4313
4314 // We have one less load for by-ref case because that load is now inside of
4315 // the reduction region.
4316 llvm::Value *redValue = originalVariable;
4317 if (!byRef)
4318 redValue =
4319 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
4320 llvm::Value *privateRedValue = builder.CreateLoad(
4321 reductionType, privateReductionVar, "red.private.value." + Twine(i));
4322 llvm::Value *reduced;
4323
4324 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4325 if (failed(handleError(res, opInst)))
4326 return failure();
4327 builder.restoreIP(res.get());
4328
4329 // For by-ref case, the store is inside of the reduction region.
4330 if (!byRef)
4331 builder.CreateStore(reduced, originalVariable);
4332 }
4333
4334 // After the construct, deallocate private reduction variables.
4335 SmallVector<Region *> reductionRegions;
4336 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4337 [](omp::DeclareReductionOp reductionDecl) {
4338 return &reductionDecl.getCleanupRegion();
4339 });
4340 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
4341 moduleTranslation, builder,
4342 "omp.reduction.cleanup")))
4343 return failure();
4344
4345 return cleanupPrivateVars(simdOp, builder, moduleTranslation, simdOp.getLoc(),
4346 privateVarsInfo);
4347}
4348
4349/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
4350static LogicalResult
4351convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
4352 LLVM::ModuleTranslation &moduleTranslation) {
4353 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4354 auto loopOp = cast<omp::LoopNestOp>(opInst);
4355
4356 if (failed(checkImplementationStatus(opInst)))
4357 return failure();
4358
4359 // Set up the source location value for OpenMP runtime.
4360 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4361
4362 // Generator of the canonical loop body.
4365 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4366 llvm::Value *iv) -> llvm::Error {
4367 // Make sure further conversions know about the induction variable.
4368 moduleTranslation.mapValue(
4369 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4370
4371 // Capture the body insertion point for use in nested loops. BodyIP of the
4372 // CanonicalLoopInfo always points to the beginning of the entry block of
4373 // the body.
4374 bodyInsertPoints.push_back(ip);
4375
4376 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4377 return llvm::Error::success();
4378
4379 // Convert the body of the loop.
4380 builder.restoreIP(ip);
4382 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
4383 if (!regionBlock)
4384 return regionBlock.takeError();
4385
4386 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4387 return llvm::Error::success();
4388 };
4389
4390 // Delegate actual loop construction to the OpenMP IRBuilder.
4391 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
4392 // loop, i.e. it has a positive step, uses signed integer semantics.
4393 // Reconsider this code when the nested loop operation clearly supports more
4394 // cases.
4395 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4396 llvm::Value *lowerBound =
4397 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
4398 llvm::Value *upperBound =
4399 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
4400 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
4401
4402 // Make sure loop trip count are emitted in the preheader of the outermost
4403 // loop at the latest so that they are all available for the new collapsed
4404 // loop will be created below.
4405 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4406 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4407 if (i != 0) {
4408 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4409 ompLoc.DL);
4410 computeIP = loopInfos.front()->getPreheaderIP();
4411 }
4412
4414 ompBuilder->createCanonicalLoop(
4415 loc, bodyGen, lowerBound, upperBound, step,
4416 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
4417
4418 if (failed(handleError(loopResult, *loopOp)))
4419 return failure();
4420
4421 loopInfos.push_back(*loopResult);
4422 }
4423
4424 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4425 loopInfos.front()->getAfterIP();
4426
4427 // Do tiling.
4428 if (const auto &tiles = loopOp.getTileSizes()) {
4429 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4431
4432 for (auto tile : tiles.value()) {
4433 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
4434 tileSizes.push_back(tileVal);
4435 }
4436
4437 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4438 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4439
4440 // Update afterIP to get the correct insertion point after
4441 // tiling.
4442 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4443 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4444 afterIP = {afterAfterBB, afterAfterBB->begin()};
4445
4446 // Update the loop infos.
4447 loopInfos.clear();
4448 for (const auto &newLoop : newLoops)
4449 loopInfos.push_back(newLoop);
4450 } // Tiling done.
4451
4452 // Do collapse.
4453 const auto &numCollapse = loopOp.getCollapseNumLoops();
4455 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4456
4457 auto newTopLoopInfo =
4458 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4459
4460 assert(newTopLoopInfo && "New top loop information is missing");
4461 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
4462 [&](OpenMPLoopInfoStackFrame &frame) {
4463 frame.loopInfo = newTopLoopInfo;
4464 return WalkResult::interrupt();
4465 });
4466
4467 // Continue building IR after the loop. Note that the LoopInfo returned by
4468 // `collapseLoops` points inside the outermost loop and is intended for
4469 // potential further loop transformations. Use the insertion point stored
4470 // before collapsing loops instead.
4471 builder.restoreIP(afterIP);
4472 return success();
4473}
4474
4475/// Convert an omp.canonical_loop to LLVM-IR
4476static LogicalResult
4477convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
4478 LLVM::ModuleTranslation &moduleTranslation) {
4479 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4480
4481 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4482 Value loopIV = op.getInductionVar();
4483 Value loopTC = op.getTripCount();
4484
4485 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
4486
4488 ompBuilder->createCanonicalLoop(
4489 loopLoc,
4490 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4491 // Register the mapping of MLIR induction variable to LLVM-IR
4492 // induction variable
4493 moduleTranslation.mapValue(loopIV, llvmIV);
4494
4495 builder.restoreIP(ip);
4497 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
4498 moduleTranslation);
4499
4500 return bodyGenStatus.takeError();
4501 },
4502 llvmTC, "omp.loop");
4503 if (!llvmOrError)
4504 return op.emitError(llvm::toString(llvmOrError.takeError()));
4505
4506 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4507 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4508 builder.restoreIP(afterIP);
4509
4510 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
4511 if (Value cli = op.getCli())
4512 moduleTranslation.mapOmpLoop(cli, llvmCLI);
4513
4514 return success();
4515}
4516
4517/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
4518/// OpenMPIRBuilder.
4519static LogicalResult
4520applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
4521 LLVM::ModuleTranslation &moduleTranslation) {
4522 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4523
4524 Value applyee = op.getApplyee();
4525 assert(applyee && "Loop to apply unrolling on required");
4526
4527 llvm::CanonicalLoopInfo *consBuilderCLI =
4528 moduleTranslation.lookupOMPLoop(applyee);
4529 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4530 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4531
4532 moduleTranslation.invalidateOmpLoop(applyee);
4533 return success();
4534}
4535
4536/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
4537/// OpenMPIRBuilder.
4538static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4539 LLVM::ModuleTranslation &moduleTranslation) {
4540 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4541 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4542
4544 SmallVector<llvm::Value *> translatedSizes;
4545
4546 for (Value size : op.getSizes()) {
4547 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
4548 assert(translatedSize &&
4549 "sizes clause arguments must already be translated");
4550 translatedSizes.push_back(translatedSize);
4551 }
4552
4553 for (Value applyee : op.getApplyees()) {
4554 llvm::CanonicalLoopInfo *consBuilderCLI =
4555 moduleTranslation.lookupOMPLoop(applyee);
4556 assert(applyee && "Canonical loop must already been translated");
4557 translatedLoops.push_back(consBuilderCLI);
4558 }
4559
4560 auto generatedLoops =
4561 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4562 if (!op.getGeneratees().empty()) {
4563 for (auto [mlirLoop, genLoop] :
4564 zip_equal(op.getGeneratees(), generatedLoops))
4565 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
4566 }
4567
4568 // CLIs can only be consumed once
4569 for (Value applyee : op.getApplyees())
4570 moduleTranslation.invalidateOmpLoop(applyee);
4571
4572 return success();
4573}
4574
4575/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
4576/// OpenMPIRBuilder.
4577static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4578 LLVM::ModuleTranslation &moduleTranslation) {
4579 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4580 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4581
4582 // Select what CLIs are going to be fused
4583 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
4584 for (size_t i = 0; i < op.getApplyees().size(); i++) {
4585 Value applyee = op.getApplyees()[i];
4586 llvm::CanonicalLoopInfo *consBuilderCLI =
4587 moduleTranslation.lookupOMPLoop(applyee);
4588 assert(applyee && "Canonical loop must already been translated");
4589 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4590 beforeFuse.push_back(consBuilderCLI);
4591 else if (op.getCount().has_value() &&
4592 i >= op.getFirst().value() + op.getCount().value() - 1)
4593 afterFuse.push_back(consBuilderCLI);
4594 else
4595 toFuse.push_back(consBuilderCLI);
4596 }
4597 assert(
4598 (op.getGeneratees().empty() ||
4599 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4600 "Wrong number of generatees");
4601
4602 // do the fuse
4603 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4604 if (!op.getGeneratees().empty()) {
4605 size_t i = 0;
4606 for (; i < beforeFuse.size(); i++)
4607 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4608 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4609 for (; i < afterFuse.size(); i++)
4610 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4611 }
4612
4613 // CLIs can only be consumed once
4614 for (Value applyee : op.getApplyees())
4615 moduleTranslation.invalidateOmpLoop(applyee);
4616
4617 return success();
4618}
4619
4620/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
4621static llvm::AtomicOrdering
4622convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
4623 if (!ao)
4624 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
4625
4626 switch (*ao) {
4627 case omp::ClauseMemoryOrderKind::Seq_cst:
4628 return llvm::AtomicOrdering::SequentiallyConsistent;
4629 case omp::ClauseMemoryOrderKind::Acq_rel:
4630 return llvm::AtomicOrdering::AcquireRelease;
4631 case omp::ClauseMemoryOrderKind::Acquire:
4632 return llvm::AtomicOrdering::Acquire;
4633 case omp::ClauseMemoryOrderKind::Release:
4634 return llvm::AtomicOrdering::Release;
4635 case omp::ClauseMemoryOrderKind::Relaxed:
4636 return llvm::AtomicOrdering::Monotonic;
4637 }
4638 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
4639}
4640
4641/// Convert omp.atomic.read operation to LLVM IR.
4642static LogicalResult
4643convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
4644 LLVM::ModuleTranslation &moduleTranslation) {
4645 auto readOp = cast<omp::AtomicReadOp>(opInst);
4646 if (failed(checkImplementationStatus(opInst)))
4647 return failure();
4648
4649 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4650 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4651 findAllocInsertPoints(builder, moduleTranslation);
4652
4653 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4654
4655 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
4656 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
4657 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
4658
4659 llvm::Type *elementType =
4660 moduleTranslation.convertType(readOp.getElementType());
4661
4662 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
4663 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
4664 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4665 return success();
4666}
4667
4668/// Converts an omp.atomic.write operation to LLVM IR.
4669static LogicalResult
4670convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
4671 LLVM::ModuleTranslation &moduleTranslation) {
4672 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4673 if (failed(checkImplementationStatus(opInst)))
4674 return failure();
4675
4676 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4677 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4678 findAllocInsertPoints(builder, moduleTranslation);
4679
4680 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4681 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
4682 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
4683 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
4684 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
4685 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
4686 /*isVolatile=*/false};
4687 builder.restoreIP(
4688 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4689 return success();
4690}
4691
4692/// Converts an LLVM dialect binary operation to the corresponding enum value
4693/// for `atomicrmw` supported binary operation.
4694static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
4696 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
4697 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
4698 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
4699 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
4700 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
4701 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
4702 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
4703 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
4704 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
4705 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4706}
4707
4708static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
4709 bool &isIgnoreDenormalMode,
4710 bool &isFineGrainedMemory,
4711 bool &isRemoteMemory) {
4712 isIgnoreDenormalMode = false;
4713 isFineGrainedMemory = false;
4714 isRemoteMemory = false;
4715 if (atomicUpdateOp &&
4716 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4717 mlir::omp::AtomicControlAttr atomicControlAttr =
4718 atomicUpdateOp.getAtomicControlAttr();
4719 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4720 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4721 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4722 }
4723}
4724
4725/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
4726static LogicalResult
4727convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
4728 llvm::IRBuilderBase &builder,
4729 LLVM::ModuleTranslation &moduleTranslation) {
4730 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4731 if (failed(checkImplementationStatus(*opInst)))
4732 return failure();
4733
4734 // Convert values and types.
4735 auto &innerOpList = opInst.getRegion().front().getOperations();
4736 bool isXBinopExpr{false};
4737 llvm::AtomicRMWInst::BinOp binop;
4738 mlir::Value mlirExpr;
4739 llvm::Value *llvmExpr = nullptr;
4740 llvm::Value *llvmX = nullptr;
4741 llvm::Type *llvmXElementType = nullptr;
4742 if (innerOpList.size() == 2) {
4743 // The two operations here are the update and the terminator.
4744 // Since we can identify the update operation, there is a possibility
4745 // that we can generate the atomicrmw instruction.
4746 mlir::Operation &innerOp = *opInst.getRegion().front().begin();
4747 if (!llvm::is_contained(innerOp.getOperands(),
4748 opInst.getRegion().getArgument(0))) {
4749 return opInst.emitError("no atomic update operation with region argument"
4750 " as operand found inside atomic.update region");
4751 }
4752 binop = convertBinOpToAtomic(innerOp);
4753 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
4754 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4755 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4756 } else {
4757 // Since the update region includes more than one operation
4758 // we will resort to generating a cmpxchg loop.
4759 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4760 }
4761 llvmX = moduleTranslation.lookupValue(opInst.getX());
4762 llvmXElementType = moduleTranslation.convertType(
4763 opInst.getRegion().getArgument(0).getType());
4764 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4765 /*isSigned=*/false,
4766 /*isVolatile=*/false};
4767
4768 llvm::AtomicOrdering atomicOrdering =
4769 convertAtomicOrdering(opInst.getMemoryOrder());
4770
4771 // Generate update code.
4772 auto updateFn =
4773 [&opInst, &moduleTranslation](
4774 llvm::Value *atomicx,
4775 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4776 Block &bb = *opInst.getRegion().begin();
4777 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
4778 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4779 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4780 return llvm::make_error<PreviouslyReportedError>();
4781
4782 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4783 assert(yieldop && yieldop.getResults().size() == 1 &&
4784 "terminator must be omp.yield op and it must have exactly one "
4785 "argument");
4786 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4787 };
4788
4789 bool isIgnoreDenormalMode;
4790 bool isFineGrainedMemory;
4791 bool isRemoteMemory;
4792 extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
4793 isRemoteMemory);
4794 // Handle ambiguous alloca, if any.
4795 auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
4796 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4797 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4798 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4799 atomicOrdering, binop, updateFn,
4800 isXBinopExpr, isIgnoreDenormalMode,
4801 isFineGrainedMemory, isRemoteMemory);
4802
4803 if (failed(handleError(afterIP, *opInst)))
4804 return failure();
4805
4806 builder.restoreIP(*afterIP);
4807 return success();
4808}
4809
4810static LogicalResult
4811convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
4812 llvm::IRBuilderBase &builder,
4813 LLVM::ModuleTranslation &moduleTranslation) {
4814 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4815 if (failed(checkImplementationStatus(*atomicCaptureOp)))
4816 return failure();
4817
4818 mlir::Value mlirExpr;
4819 bool isXBinopExpr = false, isPostfixUpdate = false;
4820 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4821
4822 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4823 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4824
4825 assert((atomicUpdateOp || atomicWriteOp) &&
4826 "internal op must be an atomic.update or atomic.write op");
4827
4828 if (atomicWriteOp) {
4829 isPostfixUpdate = true;
4830 mlirExpr = atomicWriteOp.getExpr();
4831 } else {
4832 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4833 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4834 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4835 // Find the binary update operation that uses the region argument
4836 // and get the expression to update
4837 if (innerOpList.size() == 2) {
4838 mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
4839 if (!llvm::is_contained(innerOp.getOperands(),
4840 atomicUpdateOp.getRegion().getArgument(0))) {
4841 return atomicUpdateOp.emitError(
4842 "no atomic update operation with region argument"
4843 " as operand found inside atomic.update region");
4844 }
4845 binop = convertBinOpToAtomic(innerOp);
4846 isXBinopExpr =
4847 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4848 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4849 } else {
4850 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4851 }
4852 }
4853
4854 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4855 llvm::Value *llvmX =
4856 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4857 llvm::Value *llvmV =
4858 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4859 llvm::Type *llvmXElementType = moduleTranslation.convertType(
4860 atomicCaptureOp.getAtomicReadOp().getElementType());
4861 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4862 /*isSigned=*/false,
4863 /*isVolatile=*/false};
4864 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4865 /*isSigned=*/false,
4866 /*isVolatile=*/false};
4867
4868 llvm::AtomicOrdering atomicOrdering =
4869 convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
4870
4871 auto updateFn =
4872 [&](llvm::Value *atomicx,
4873 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4874 if (atomicWriteOp)
4875 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
4876 Block &bb = *atomicUpdateOp.getRegion().begin();
4877 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
4878 atomicx);
4879 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4880 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4881 return llvm::make_error<PreviouslyReportedError>();
4882
4883 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4884 assert(yieldop && yieldop.getResults().size() == 1 &&
4885 "terminator must be omp.yield op and it must have exactly one "
4886 "argument");
4887 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4888 };
4889
4890 bool isIgnoreDenormalMode;
4891 bool isFineGrainedMemory;
4892 bool isRemoteMemory;
4893 extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
4894 isFineGrainedMemory, isRemoteMemory);
4895 // Handle ambiguous alloca, if any.
4896 auto allocaIP = findAllocInsertPoints(builder, moduleTranslation);
4897 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4898 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4899 ompBuilder->createAtomicCapture(
4900 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4901 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4902 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4903
4904 if (failed(handleError(afterIP, *atomicCaptureOp)))
4905 return failure();
4906
4907 builder.restoreIP(*afterIP);
4908 return success();
4909}
4910
4911/// Helper to extract the OMPAtomicCompareOp from an integer comparison
4912/// predicate. Returns std::nullopt for unsupported predicates.
4913static std::optional<llvm::omp::OMPAtomicCompareOp>
4914convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
4915 switch (predicate) {
4916 case LLVM::ICmpPredicate::eq:
4917 return llvm::omp::OMPAtomicCompareOp::EQ;
4918 case LLVM::ICmpPredicate::slt:
4919 case LLVM::ICmpPredicate::ult:
4920 return llvm::omp::OMPAtomicCompareOp::MIN;
4921 case LLVM::ICmpPredicate::sgt:
4922 case LLVM::ICmpPredicate::ugt:
4923 return llvm::omp::OMPAtomicCompareOp::MAX;
4924 default:
4925 return std::nullopt;
4926 }
4927}
4928
4929/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
4930/// predicate. Returns std::nullopt for unsupported predicates.
4931static std::optional<llvm::omp::OMPAtomicCompareOp>
4932convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
4933 switch (predicate) {
4934 case LLVM::FCmpPredicate::oeq:
4935 case LLVM::FCmpPredicate::ueq:
4936 return llvm::omp::OMPAtomicCompareOp::EQ;
4937 case LLVM::FCmpPredicate::olt:
4938 case LLVM::FCmpPredicate::ult:
4939 return llvm::omp::OMPAtomicCompareOp::MIN;
4940 case LLVM::FCmpPredicate::ogt:
4941 case LLVM::FCmpPredicate::ugt:
4942 return llvm::omp::OMPAtomicCompareOp::MAX;
4943 default:
4944 return std::nullopt;
4945 }
4946}
4947
4948/// Converts an omp.atomic.compare operation to LLVM IR.
4949///
4950/// if (x == e) x = d
4951/// The region contains a comparison + select pattern:
4952/// ^bb0(%xval: T):
4953/// %cmp = llvm.icmp/fcmp <pred> %xval, %e : T
4954/// %sel = llvm.select %cmp, %d, %xval : i1, T
4955/// omp.yield(%sel : T)
4956///
4957/// From MLIR extract:
4958/// 1) comparison operator
4959/// 2) expected value (e)
4960/// 3) desired value (d)
4961/// These are passed to OpenMPIRBuilder::createAtomicCompare which generates
4962/// the actual cmpxchg / atomicrmw instruction.
4963///
4964static LogicalResult
4965convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp,
4966 llvm::IRBuilderBase &builder,
4967 LLVM::ModuleTranslation &moduleTranslation) {
4968 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4969 if (failed(checkImplementationStatus(*atomicCompareOp)))
4970 return failure();
4971
4972 Region &region = atomicCompareOp.getRegion();
4973 Block &block = region.front();
4974
4975 // Determine element type from the region block argument
4976 llvm::Type *llvmXElementType =
4977 moduleTranslation.convertType(block.getArgument(0).getType());
4978 if (!llvmXElementType)
4979 return atomicCompareOp.emitError(
4980 "unable to determine element type for atomic compare");
4981
4982 llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
4983
4984 // IsSigned is determined from the comparison predicate in the region.
4985 // Signed ICmp predicates (slt/sgt) set this to true; unsigned (ult/ugt)
4986 // leave it false. For EQ and float comparisons, signedness is irrelevant.
4987 bool isSigned = false;
4988 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4989 isSigned,
4990 /*IsVolatile=*/false};
4991
4992 llvm::AtomicOrdering atomicOrdering =
4993 convertAtomicOrdering(atomicCompareOp.getMemoryOrder());
4994
4995 auto isAtomicComparePatternOp = [](Operation &op) {
4996 return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
4997 LLVM::OrOp>(op);
4998 };
4999
5000 // Pre-translate operations inside the region that compute e and d (e.g.,
5001 // GEP, loads for dereferencing Fortran pointers) but are not part of the
5002 // atomic compare-and-swap pattern (icmp/fcmp, select, and/or).
5003 //
5004 // 1) Validity: The OpenMP spec requires e and d to be evaluated before the
5005 // atomic operation, so emitting their computation here is correct.
5006 // 2) Memory effects: These ops only depend on values defined outside the
5007 // region. They cannot observe the block argument (%xval), which is the
5008 // value loaded atomically by cmpxchg and does not exist yet.
5009 // 3) Invariant enforcement: The `allOperandsMapped` check below skips any
5010 // op whose operands include the unmapped block argument, guaranteeing
5011 // only region-external-dependent ops are pre-translated.
5012 for (Operation &op : block.without_terminator()) {
5013 // Skip operations that form the atomic compare pattern — these are
5014 // not emitted as individual instructions but are analyzed below to
5015 // extract the comparison predicate, expected value (e), and desired
5016 // value (d) for generating a single cmpxchg/atomicrmw.
5017 if (isAtomicComparePatternOp(op))
5018 continue;
5019
5020 // Avoid translating ops that depend on the unmapped block argument.
5021 bool allOperandsMapped = llvm::all_of(op.getOperands(), [&](mlir::Value v) {
5022 return moduleTranslation.lookupValue(v) != nullptr;
5023 });
5024 if (!allOperandsMapped)
5025 continue;
5026
5027 if (failed(moduleTranslation.convertOperation(op, builder)))
5028 return atomicCompareOp.emitError(
5029 "failed to translate operation inside atomic compare region");
5030 }
5031
5032 // Look up a value that may have been pre-translated or defined outside the
5033 // region.
5034 auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
5035 // Check if the value is already mapped (pre-translated or defined outside).
5036 if (llvm::Value *existing = moduleTranslation.lookupValue(val))
5037 return existing;
5038 // Fallback for a single LoadOp whose address is mapped but whose result
5039 // was not pre-translated.
5040 if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
5041 if (loadOp->getParentRegion() == &region) {
5042 llvm::Value *loadAddr = moduleTranslation.lookupValue(loadOp.getAddr());
5043 if (!loadAddr)
5044 return nullptr;
5045 llvm::Type *loadType =
5046 moduleTranslation.convertType(loadOp.getResult().getType());
5047 return builder.CreateLoad(loadType, loadAddr);
5048 }
5049 }
5050 return nullptr;
5051 };
5052
5053 // Walk the region to extract comparison predicate, eVal, and dVal.
5054 // if (x == eVal) x = dVal
5055 llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5056 llvm::Value *eVal = nullptr;
5057 llvm::Value *dVal = nullptr;
5058 bool isXBinopExpr = false;
5059
5060 auto traceToAggregate = [](mlir::Value v) -> mlir::Value {
5061 if (auto extractOp = v.getDefiningOp<LLVM::ExtractValueOp>())
5062 return extractOp.getContainer();
5063 return nullptr;
5064 };
5065
5066 // Check for a decomposed complex comparison pattern:
5067 // %re_x = llvm.extractvalue %xval[0]
5068 // %re_e = llvm.extractvalue %eStruct[0]
5069 // %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
5070 // %im_x = llvm.extractvalue %xval[1]
5071 // %im_e = llvm.extractvalue %eStruct[1]
5072 // %cmp_im = llvm.fcmp "oeq" %im_x, %im_e
5073 // %cmp = llvm.and %cmp_re, %cmp_im (for EQ)
5074 // Detect this by looking for AndOp/OrOp whose operands are both FCmpOps
5075 // operating on ExtractValueOps from the block argument.
5076 bool isComplexPattern = false;
5077 for (Operation &op : block.getOperations()) {
5078 if (!isa<LLVM::AndOp, LLVM::OrOp>(op))
5079 continue;
5080
5081 // Using : %cmp = llvm.and %cmp_re, %cmp_im
5082 auto lhsFcmp = op.getOperand(0).getDefiningOp<LLVM::FCmpOp>();
5083 auto rhsFcmp = op.getOperand(1).getDefiningOp<LLVM::FCmpOp>();
5084 if (!lhsFcmp || !rhsFcmp)
5085 continue;
5086
5087 // Using : %cmp_re = llvm.fcmp "oeq" %re_x, %re_e
5088 // Check presence of x (block argument) and get e.
5089 mlir::Value lhsAgg0 = traceToAggregate(lhsFcmp.getOperand(0));
5090 mlir::Value lhsAgg1 = traceToAggregate(lhsFcmp.getOperand(1));
5091 bool lhsXIsOp0 = (lhsAgg0 == block.getArgument(0));
5092 bool lhsXIsOp1 = (lhsAgg1 == block.getArgument(0));
5093 if (!lhsXIsOp0 && !lhsXIsOp1)
5094 continue;
5095 mlir::Value eAggregate = lhsXIsOp0 ? lhsAgg1 : lhsAgg0;
5096 if (!eAggregate)
5097 continue;
5098
5099 if (isa<LLVM::AndOp>(op))
5100 compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
5101 else
5102 // OrOp corresponds to NE, which is not a valid atomic compare op.
5103 return atomicCompareOp.emitError(
5104 "unsupported comparison predicate (NE) for complex atomic compare");
5105
5106 isXBinopExpr = lhsXIsOp0;
5107 eVal = materializeValue(eAggregate);
5108 isComplexPattern = true;
5109 break;
5110 }
5111
5112 if (isComplexPattern) {
5113 // dVal from SelectOp or YieldOp.
5114 for (Operation &op : block.getOperations()) {
5115 if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5116 dVal = materializeValue(selectOp.getTrueValue());
5117 break;
5118 }
5119 }
5120 if (!dVal) {
5121 auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
5122 if (yieldOp.getResults().empty())
5123 return atomicCompareOp.emitError(
5124 "failed to extract desired value (d) from atomic compare region");
5125 dVal = materializeValue(yieldOp.getResults()[0]);
5126 }
5127
5128 const llvm::DataLayout &DL =
5129 builder.GetInsertBlock()->getModule()->getDataLayout();
5130 unsigned totalBits =
5131 DL.getTypeStoreSizeInBits(llvmXElementType).getFixedValue();
5132
5133 llvm::IntegerType *intTy =
5134 llvm::IntegerType::get(builder.getContext(), totalBits);
5135
5136 llvm::Align complexAlign = DL.getABITypeAlign(llvmXElementType);
5137 llvm::Align intAlign = DL.getABITypeAlign(intTy);
5138 llvm::Align maxAlign = std::max(complexAlign, intAlign);
5139
5140 llvm::AllocaInst *eAlloca =
5141 builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.e");
5142 eAlloca->setAlignment(maxAlign);
5143 llvm::AllocaInst *dAlloca =
5144 builder.CreateAlloca(llvmXElementType, nullptr, "cmplx.d");
5145 dAlloca->setAlignment(maxAlign);
5146
5147 builder.CreateAlignedStore(eVal, eAlloca, maxAlign);
5148 llvm::Value *eInt =
5149 builder.CreateAlignedLoad(intTy, eAlloca, maxAlign, "cmplx.e.int");
5150 builder.CreateAlignedStore(dVal, dAlloca, maxAlign);
5151 llvm::Value *dInt =
5152 builder.CreateAlignedLoad(intTy, dAlloca, maxAlign, "cmplx.d.int");
5153
5154 llvm::AtomicOrdering failOrdering =
5155 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(atomicOrdering);
5156 auto *cmpXchg = builder.CreateAtomicCmpXchg(llvmX, eInt, dInt, maxAlign,
5157 atomicOrdering, failOrdering);
5158 cmpXchg->setWeak(atomicCompareOp.getWeak());
5159
5160 // Emit flush after atomic compare if needed (for release, acq_rel,
5161 // seq_cst orderings).
5162 if (atomicOrdering == llvm::AtomicOrdering::Release ||
5163 atomicOrdering == llvm::AtomicOrdering::AcquireRelease ||
5164 atomicOrdering == llvm::AtomicOrdering::SequentiallyConsistent) {
5165 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5166 ompBuilder->createFlush(ompLoc);
5167 }
5168 return success();
5169 } else {
5170
5171 for (Operation &op : block.getOperations()) {
5172 if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
5173 auto maybeOp =
5174 convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
5175 if (!maybeOp)
5176 return atomicCompareOp.emitError(
5177 "unsupported comparison predicate in atomic compare");
5178 compareOp = *maybeOp;
5179
5180 LLVM::ICmpPredicate pred = icmpOp.getPredicate();
5181 isSigned = (pred == LLVM::ICmpPredicate::slt ||
5182 pred == LLVM::ICmpPredicate::sgt ||
5183 pred == LLVM::ICmpPredicate::sle ||
5184 pred == LLVM::ICmpPredicate::sge);
5185
5186 // Identify which operand is the block argument (x) and which is e.
5187 isXBinopExpr = (icmpOp.getOperand(0) == block.getArgument(0));
5188 mlir::Value eOperand =
5189 isXBinopExpr ? icmpOp.getOperand(1) : icmpOp.getOperand(0);
5190 eVal = materializeValue(eOperand);
5191 } else if (auto fcmpOp = dyn_cast<LLVM::FCmpOp>(op)) {
5192 auto maybeOp =
5193 convertFCmpPredicateToAtomicCompareOp(fcmpOp.getPredicate());
5194 if (!maybeOp)
5195 return atomicCompareOp.emitError(
5196 "unsupported comparison predicate in atomic compare");
5197 compareOp = *maybeOp;
5198
5199 isXBinopExpr = (fcmpOp.getOperand(0) == block.getArgument(0));
5200 mlir::Value eOperand =
5201 isXBinopExpr ? fcmpOp.getOperand(1) : fcmpOp.getOperand(0);
5202 eVal = materializeValue(eOperand);
5203 } else if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5204 if (!dVal)
5205 dVal = materializeValue(selectOp.getTrueValue());
5206 }
5207 }
5208 }
5209
5210 // For non-complex patterns, also extract dVal from SelectOp.
5211 if (!dVal) {
5212 for (Operation &op : block.getOperations()) {
5213 if (auto selectOp = dyn_cast<LLVM::SelectOp>(op)) {
5214 dVal = materializeValue(selectOp.getTrueValue());
5215 break;
5216 }
5217 }
5218 }
5219
5220 if (!eVal)
5221 return atomicCompareOp.emitError(
5222 "failed to extract expected value (e) from atomic compare region");
5223 if (!dVal) {
5224 // Fall back to the yield operand.
5225 auto yieldOp = cast<omp::YieldOp>(block.getTerminator());
5226 if (yieldOp.getResults().empty())
5227 return atomicCompareOp.emitError(
5228 "failed to extract desired value (d) from atomic compare region");
5229 dVal = materializeValue(yieldOp.getResults()[0]);
5230 }
5231
5232 llvmAtomicX.IsSigned = isSigned;
5233
5234 llvm::OpenMPIRBuilder::AtomicOpValue vOpVal = {nullptr, nullptr, false,
5235 false};
5236 llvm::OpenMPIRBuilder::AtomicOpValue rOpVal = {nullptr, nullptr, false,
5237 false};
5238 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5239
5240 bool isWeak = atomicCompareOp.getWeak();
5241
5242 bool savedHandleFPNegZero = ompBuilder->setHandleFPNegZero(true);
5243 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5244 ompBuilder->createAtomicCompare(ompLoc, llvmAtomicX, vOpVal, rOpVal, eVal,
5245 dVal, atomicOrdering, compareOp,
5246 isXBinopExpr, false, false, isWeak);
5247 ompBuilder->setHandleFPNegZero(savedHandleFPNegZero);
5248
5249 if (failed(handleError(afterIP, *atomicCompareOp)))
5250 return failure();
5251
5252 builder.restoreIP(*afterIP);
5253 return success();
5254}
5255
5256static llvm::omp::Directive convertCancellationConstructType(
5257 omp::ClauseCancellationConstructType directive) {
5258 switch (directive) {
5259 case omp::ClauseCancellationConstructType::Loop:
5260 return llvm::omp::Directive::OMPD_for;
5261 case omp::ClauseCancellationConstructType::Parallel:
5262 return llvm::omp::Directive::OMPD_parallel;
5263 case omp::ClauseCancellationConstructType::Sections:
5264 return llvm::omp::Directive::OMPD_sections;
5265 case omp::ClauseCancellationConstructType::Taskgroup:
5266 return llvm::omp::Directive::OMPD_taskgroup;
5267 }
5268 llvm_unreachable("Unhandled cancellation construct type");
5269}
5270
5271static LogicalResult
5272convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
5273 LLVM::ModuleTranslation &moduleTranslation) {
5274 if (failed(checkImplementationStatus(*op.getOperation())))
5275 return failure();
5276
5277 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5278 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5279
5280 llvm::Value *ifCond = nullptr;
5281 if (Value ifVar = op.getIfExpr())
5282 ifCond = moduleTranslation.lookupValue(ifVar);
5283
5284 llvm::omp::Directive cancelledDirective =
5285 convertCancellationConstructType(op.getCancelDirective());
5286
5287 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5288 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
5289
5290 if (failed(handleError(afterIP, *op.getOperation())))
5291 return failure();
5292
5293 builder.restoreIP(afterIP.get());
5294
5295 return success();
5296}
5297
5298static LogicalResult
5299convertOmpCancellationPoint(omp::CancellationPointOp op,
5300 llvm::IRBuilderBase &builder,
5301 LLVM::ModuleTranslation &moduleTranslation) {
5302 if (failed(checkImplementationStatus(*op.getOperation())))
5303 return failure();
5304
5305 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5306 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5307
5308 llvm::omp::Directive cancelledDirective =
5309 convertCancellationConstructType(op.getCancelDirective());
5310
5311 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5312 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
5313
5314 if (failed(handleError(afterIP, *op.getOperation())))
5315 return failure();
5316
5317 builder.restoreIP(afterIP.get());
5318
5319 return success();
5320}
5321
5322/// Converts an OpenMP Threadprivate operation into LLVM IR using
5323/// OpenMPIRBuilder.
5324static LogicalResult
5325convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
5326 LLVM::ModuleTranslation &moduleTranslation) {
5327 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5328 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5329 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
5330
5331 if (failed(checkImplementationStatus(opInst)))
5332 return failure();
5333
5334 Value symAddr = threadprivateOp.getSymAddr();
5335 auto *symOp = symAddr.getDefiningOp();
5336
5337 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
5338 symOp = asCast.getOperand().getDefiningOp();
5339
5340 if (!isa<LLVM::AddressOfOp>(symOp))
5341 return opInst.emitError("Addressing symbol not found");
5342 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
5343
5344 LLVM::GlobalOp global =
5345 addressOfOp.getGlobal(moduleTranslation.symbolTable());
5346 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
5347 llvm::Type *type = globalValue->getValueType();
5348 llvm::TypeSize typeSize =
5349 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
5350 type);
5351 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
5352 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
5353 ompLoc, globalValue, size, global.getSymName() + ".cache");
5354 moduleTranslation.mapValue(opInst.getResult(0), callInst);
5355
5356 return success();
5357}
5358
5359static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
5360convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
5361 switch (deviceClause) {
5362 case mlir::omp::DeclareTargetDeviceType::host:
5363 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
5364 break;
5365 case mlir::omp::DeclareTargetDeviceType::nohost:
5366 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
5367 break;
5368 case mlir::omp::DeclareTargetDeviceType::any:
5369 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
5370 break;
5371 }
5372 llvm_unreachable("unhandled device clause");
5373}
5374
/// Translates a declare-target capture clause into the corresponding
/// OffloadEntriesInfoManager global-variable entry kind. The four clause
/// values handled below (`to`, `link`, `enter`, `none`) are each mapped
/// explicitly; any other value reaches llvm_unreachable.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  case mlir::omp::DeclareTargetCaptureClause::none:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
  }
  llvm_unreachable("unhandled capture clause");
}
5390
  // Resolve `value` to the symbol operation it addresses. Look through a
  // single llvm.addrspacecast, then accept only an llvm.mlir.addressof.
  // Anything else, including a block argument, yields nullptr.
  Operation *op = value.getDefiningOp();
  if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
    op = addrCast->getOperand(0).getDefiningOp();
  if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
    // The symbol is looked up in the enclosing mlir::ModuleOp. The result
    // is whatever operation defines it, which callers cast as needed.
    auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
    return modOp.lookupSymbol(addressOfOp.getGlobalName());
  }
  return nullptr;
}
5401
  // Walk up the producer chain of `value`, stepping through
  // llvm.addrspacecast and through hlfir.declare / fir.declare (via their
  // first operand). Stop at the first other operation, at a declare with no
  // operands, or at a value with no defining op (block argument).
  while (Operation *op = value.getDefiningOp()) {
    if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
      value = addrCast.getOperand();
    // Traces through hlfir.declare, fir.declare to reach the base address and
    // use for type lookup. These ops are matched by name, presumably because
    // their dialects are not available to this translation.
    else if (op->getName().getIdentifier() &&
             (op->getName().getIdentifier().str() == "hlfir.declare" ||
              op->getName().getIdentifier().str() == "fir.declare")) {
      if (op->getNumOperands() > 0)
        value = op->getOperand(0);
      else
        break;
    } else {
      break;
    }
  }
  return value;
}
5421
5422static llvm::SmallString<64>
5423getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
5424 llvm::OpenMPIRBuilder &ompBuilder,
5425 llvm::vfs::FileSystem &vfs) {
5426 llvm::SmallString<64> suffix;
5427 llvm::raw_svector_ostream os(suffix);
5428 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
5429 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
5430 auto fileInfoCallBack = [&loc]() {
5431 return std::pair<std::string, uint64_t>(
5432 llvm::StringRef(loc.getFilename()), loc.getLine());
5433 };
5434
5435 os << llvm::format(
5436 "_%x",
5437 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs).FileID);
5438 }
5439 os << "_decl_tgt_ref_ptr";
5440
5441 return suffix;
5442}
5443
5444static bool isDeclareTargetLink(Value value) {
5445 if (auto declareTargetGlobal =
5446 dyn_cast_if_present<omp::DeclareTargetInterface>(
5447 getGlobalOpFromValue(value)))
5448 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5449 omp::DeclareTargetCaptureClause::link)
5450 return true;
5451 return false;
5452}
5453
5454static bool isDeclareTargetTo(Value value) {
5455 if (auto declareTargetGlobal =
5456 dyn_cast_if_present<omp::DeclareTargetInterface>(
5457 getGlobalOpFromValue(value)))
5458 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
5459 omp::DeclareTargetCaptureClause::to ||
5460 declareTargetGlobal.getDeclareTargetCaptureClause() ==
5461 omp::DeclareTargetCaptureClause::enter)
5462 return true;
5463 return false;
5464}
5465
// Returns the reference pointer generated by the lowering of the declare
// target operation in two cases: the link clause is used, or the to (or
// enter) clause is used in unified shared memory (USM) mode. Returns nullptr
// when `value` does not resolve to an llvm.mlir.global or neither case
// applies. The module lookup itself may also return nullptr if no value with
// the computed name exists.
static llvm::Value *
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (auto gOp =
          dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
    // In this case, we must utilise the reference pointer generated by
    // the declare target operation, similar to Clang.
    if (isDeclareTargetLink(value) ||
        (isDeclareTargetTo(value) &&
         ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
          gOp, *ompBuilder, moduleTranslation.getFileSystem());

      // A global whose name already contains the suffix is itself the
      // reference pointer; otherwise look up "<name><suffix>".
      if (gOp.getSymName().contains(suffix))
        return moduleTranslation.getLLVMModule()->getNamedValue(
            gOp.getSymName());

      return moduleTranslation.getLLVMModule()->getNamedValue(
          (gOp.getSymName().str() + suffix.str()).str());
    }
  }
  return nullptr;
}
5493
5494namespace {
5495// Append customMappers information to existing MapInfosTy
5496struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
5497 SmallVector<Operation *, 4> Mappers;
5498
5499 /// Append arrays in \a CurInfo.
5500 void append(MapInfosTy &curInfo) {
5501 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
5502 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
5503 }
5504};
5505// A small helper structure to contain data gathered
5506// for map lowering and coalese it into one area and
5507// avoiding extra computations such as searches in the
5508// llvm module for lowered mapped variables or checking
5509// if something is declare target (and retrieving the
5510// value) more than neccessary.
5511struct MapInfoData : MapInfosTy {
5512 llvm::SmallVector<bool, 4> IsDeclareTarget;
5513 llvm::SmallVector<bool, 4> IsAMember;
5514 // Identify if mapping was added by mapClause or use_device clauses.
5515 llvm::SmallVector<bool, 4> IsAMapping;
5516 llvm::SmallVector<mlir::Operation *, 4> MapClause;
5517 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
5518 // Stripped off array/pointer to get the underlying
5519 // element type
5520 llvm::SmallVector<llvm::Type *, 4> BaseType;
5521
5522 /// Append arrays in \a CurInfo.
5523 void append(MapInfoData &CurInfo) {
5524 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
5525 CurInfo.IsDeclareTarget.end());
5526 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
5527 OriginalValue.append(CurInfo.OriginalValue.begin(),
5528 CurInfo.OriginalValue.end());
5529 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
5530 MapInfosTy::append(CurInfo);
5531 }
5532};
5533
5534enum class TargetDirectiveEnumTy : uint32_t {
5535 None = 0,
5536 Target = 1,
5537 TargetData = 2,
5538 TargetEnterData = 3,
5539 TargetExitData = 4,
5540 TargetUpdate = 5
5541};
5542
5543static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
5544 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
5545 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
5546 .Case([](omp::TargetEnterDataOp) {
5547 return TargetDirectiveEnumTy::TargetEnterData;
5548 })
5549 .Case([&](omp::TargetExitDataOp) {
5550 return TargetDirectiveEnumTy::TargetExitData;
5551 })
5552 .Case([&](omp::TargetUpdateOp) {
5553 return TargetDirectiveEnumTy::TargetUpdate;
5554 })
5555 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
5556 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
5557}
5558
5559} // namespace
5560
5561static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
5562 DataLayout &dl) {
5563 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
5564 arrTy.getElementType()))
5565 return getArrayElementSizeInBits(nestedArrTy, dl);
5566 return dl.getTypeSizeInBits(arrTy.getElementType());
5567}
5568
5569// This function calculates the size to be offloaded for a specified type, given
5570// its associated map clause (which can contain bounds information which affects
5571// the total size), this size is calculated based on the underlying element type
5572// e.g. given a 1-D array of ints, we will calculate the size from the integer
5573// type * number of elements in the array. This size can be used in other
5574// calculations but is ultimately used as an argument to the OpenMP runtimes
5575// kernel argument structure which is generated through the combinedInfo data
5576// structures.
5577// This function is somewhat equivalent to Clang's getExprTypeSize inside of
5578// CGOpenMPRuntime.cpp.
5579static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
5580 Operation *clauseOp,
5581 llvm::Value *basePointer,
5582 llvm::Type *baseType,
5583 llvm::IRBuilderBase &builder,
5584 LLVM::ModuleTranslation &moduleTranslation) {
5585 if (auto memberClause =
5586 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5587 // This calculates the size to transfer based on bounds and the underlying
5588 // element type, provided bounds have been specified (Fortran
5589 // pointers/allocatables/target and arrays that have sections specified fall
5590 // into this as well)
5591 if (!memberClause.getBounds().empty()) {
5592 llvm::Value *elementCount = builder.getInt64(1);
5593 for (auto bounds : memberClause.getBounds()) {
5594 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5595 bounds.getDefiningOp())) {
5596 // The below calculation for the size to be mapped calculated from the
5597 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
5598 // multiply by the underlying element types byte size to get the full
5599 // size to be offloaded based on the bounds
5600 elementCount = builder.CreateMul(
5601 elementCount,
5602 builder.CreateAdd(
5603 builder.CreateSub(
5604 moduleTranslation.lookupValue(boundOp.getUpperBound()),
5605 moduleTranslation.lookupValue(boundOp.getLowerBound())),
5606 builder.getInt64(1)));
5607 }
5608 }
5609
5610 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
5611 // the size in inconsistent byte or bit format.
5612 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
5613 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5614 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
5615
5616 // The size in bytes x number of elements, the sizeInBytes stored is
5617 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
5618 // size, so we do some on the fly runtime math to get the size in
5619 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
5620 // some adjustment for members with more complex types.
5621 return builder.CreateMul(elementCount,
5622 builder.getInt64(underlyingTypeSzInBits / 8));
5623 }
5624 }
5625
5626 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
5627}
5628
5629// Convert the MLIR map flag set to the runtime map flag set for embedding
5630// in LLVM-IR. This is important as the two bit-flag lists do not correspond
5631// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
5632// Certain flags are discarded here such as RefPtee and co.
5633static llvm::omp::OpenMPOffloadMappingFlags
5634convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
5635 const bool hasExplicitMap =
5636 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
5637 omp::ClauseMapFlags::none;
5638
5639 llvm::omp::OpenMPOffloadMappingFlags mapType =
5640 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5641
5642 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::to))
5643 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5644
5645 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::from))
5646 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5647
5648 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::always))
5649 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5650
5651 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::del))
5652 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5653
5654 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::return_param))
5655 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5656
5657 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::priv))
5658 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5659
5660 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::literal))
5661 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5662
5663 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::implicit))
5664 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5665
5666 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::close))
5667 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5668
5669 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::present))
5670 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5671
5672 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::ompx_hold))
5673 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5674
5675 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::attach))
5676 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5677
5678 if (bitEnumContainsAll(mlirFlags, omp::ClauseMapFlags::is_device_ptr)) {
5679 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5680 if (!hasExplicitMap)
5681 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5682 }
5683
5684 return mapType;
5685}
5686
5688 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
5689 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
5690 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
5691 ArrayRef<Value> useDevAddrOperands = {},
5692 ArrayRef<Value> hasDevAddrOperands = {}) {
5693
5694 auto checkRefPtrOrPteeMapWithAttach = [](omp::ClauseMapFlags mapType) {
5695 bool hasRefType =
5696 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptr) ||
5697 bitEnumContainsAll(mapType, omp::ClauseMapFlags::ref_ptee);
5698 return hasRefType &&
5699 bitEnumContainsAll(mapType, omp::ClauseMapFlags::attach);
5700 };
5701
5702 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
5703 // Check if this is a member mapping and correctly assign that it is, if
5704 // it is a member of a larger object.
5705 // TODO: Need better handling of members, and distinguishing of members
5706 // that are implicitly allocated on device vs explicitly passed in as
5707 // arguments.
5708 // TODO: May require some further additions to support nested record
5709 // types, i.e. member maps that can have member maps.
5710 for (Value mapValue : mapVars) {
5711 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5712 for (auto member : map.getMembers())
5713 if (member == mapOp)
5714 return true;
5715 }
5716 return false;
5717 };
5718
5719 // Process MapOperands
5720 for (Value mapValue : mapVars) {
5721 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5722 bool isRefPtrOrPteeMapWithAttach =
5723 checkRefPtrOrPteeMapWithAttach(mapOp.getMapType());
5724 Value offloadPtr = (mapOp.getVarPtrPtr() && !isRefPtrOrPteeMapWithAttach)
5725 ? mapOp.getVarPtrPtr()
5726 : mapOp.getVarPtr();
5727 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
5728 mapData.Pointers.push_back(
5729 isRefPtrOrPteeMapWithAttach
5730 ? moduleTranslation.lookupValue(mapOp.getVarPtrPtr())
5731 : mapData.OriginalValue.back());
5732
5733 if (llvm::Value *refPtr =
5734 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
5735 mapData.IsDeclareTarget.push_back(true);
5736 mapData.BasePointers.push_back(refPtr);
5737 } else if (isDeclareTargetTo(offloadPtr)) {
5738 mapData.IsDeclareTarget.push_back(true);
5739 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5740 } else { // regular mapped variable
5741 mapData.IsDeclareTarget.push_back(false);
5742 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5743 }
5744
5745 // In every situation we currently have if we have a varPtrPtr present
5746 // we wish to utilise it's type for the base type, main cases are
5747 // currently Fortran descriptor base address maps and attach maps.
5748 mapData.BaseType.push_back(moduleTranslation.convertType(
5749 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5750 : mapOp.getVarPtrType()));
5751
5752 // For the attach map cases, it's a little odd, as we effectively have to
5753 // utilise the base address (including all bounds offsets) for the pointer
5754 // field, the pointer address for the base address field, and the pointer
5755 // not the data (base addresses) size. So we end up with a mix of base
5756 // types and sizes we wish to insert here.
5757 mlir::Type sizeType = (isRefPtrOrPteeMapWithAttach || !mapOp.getVarPtrPtr())
5758 ? mapOp.getVarPtrType()
5759 : mapOp.getVarPtrPtrType().value();
5760 mapData.Sizes.push_back(getSizeInBytes(
5761 dl, sizeType, isRefPtrOrPteeMapWithAttach ? nullptr : mapOp,
5762 mapData.Pointers.back(), moduleTranslation.convertType(sizeType),
5763 builder, moduleTranslation));
5764 mapData.MapClause.push_back(mapOp.getOperation());
5765 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
5766 mapData.Names.push_back(LLVM::createMappingInformation(
5767 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5768 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5769 if (mapOp.getMapperId())
5770 mapData.Mappers.push_back(
5772 mapOp, mapOp.getMapperIdAttr()));
5773 else
5774 mapData.Mappers.push_back(nullptr);
5775 mapData.IsAMapping.push_back(true);
5776 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5777 }
5778
5779 auto findMapInfo = [&mapData](llvm::Value *val,
5780 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
5781 size_t memberCount) {
5782 unsigned index = 0;
5783 bool found = false;
5784 for (llvm::Value *basePtr : mapData.OriginalValue) {
5785 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[index]);
5786 // TODO: Currently we define an equivalent mapping as
5787 // the same base pointer and an equivalent member count, but
5788 // that is a loose definition. We may have to extend to check
5789 // for other fields (varPtrPtr/individual members being mapped).
5790 // Note: Attach maps are not the same as a normal data transfer
5791 // they specify to the runtime to perform an attach map and they
5792 // (at least at the moment) are never something we would aim to
5793 // return in a use_dev_* clause, so they are skipped in terms of
5794 // duplicate maps.
5795 bool isAttachMap =
5796 (mapData.Types[index] &
5797 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
5798 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5799 if (!isAttachMap && basePtr == val && mapData.IsAMapping[index] &&
5800 memberCount == mapOp.getMembers().size()) {
5801 found = true;
5802 mapData.Types[index] |=
5803 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5804 mapData.DevicePointers[index] = devInfoTy;
5805 }
5806 index++;
5807 }
5808 return found;
5809 };
5810
5811 // Process useDevPtr(Addr)Operands
5812 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
5813 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5814 for (Value mapValue : useDevOperands) {
5815 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5816 Value offloadPtr =
5817 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5818 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5819
5820 // Check if map info is already present for this entry.
5821 if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
5822 mapData.OriginalValue.push_back(origValue);
5823 mapData.Pointers.push_back(mapData.OriginalValue.back());
5824 mapData.IsDeclareTarget.push_back(false);
5825 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5826 mlir::Type baseTy = mapOp.getVarPtrPtr()
5827 ? mapOp.getVarPtrPtrType().value()
5828 : mapOp.getVarPtrType();
5829 mapData.BaseType.push_back(moduleTranslation.convertType(baseTy));
5830 mapData.Sizes.push_back(builder.getInt64(0));
5831 mapData.MapClause.push_back(mapOp.getOperation());
5832 mapData.Types.push_back(
5833 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5834 mapData.Names.push_back(LLVM::createMappingInformation(
5835 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5836 mapData.DevicePointers.push_back(devInfoTy);
5837 mapData.Mappers.push_back(nullptr);
5838 mapData.IsAMapping.push_back(false);
5839 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5840 }
5841 }
5842 };
5843
5844 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5845 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5846
5847 for (Value mapValue : hasDevAddrOperands) {
5848 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5849 Value offloadPtr =
5850 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5851 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5852 auto mapType = convertClauseMapFlags(mapOp.getMapType());
5853 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5854 bool isDevicePtr =
5855 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5856 omp::ClauseMapFlags::none;
5857
5858 mapData.OriginalValue.push_back(origValue);
5859 mapData.BasePointers.push_back(origValue);
5860 mapData.Pointers.push_back(origValue);
5861 mapData.IsDeclareTarget.push_back(false);
5862
5863 mlir::Type baseTy = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtrType().value()
5864 : mapOp.getVarPtrType();
5865 mapData.BaseType.push_back(moduleTranslation.convertType(baseTy));
5866 mapData.Sizes.push_back(builder.getInt64(dl.getTypeSize(baseTy)));
5867
5868 mapData.MapClause.push_back(mapOp.getOperation());
5869 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5870 // Descriptors are mapped with the ALWAYS flag, since they can get
5871 // rematerialized, so the address of the decriptor for a given object
5872 // may change from one place to another.
5873 mapData.Types.push_back(mapType);
5874 // Technically it's possible for a non-descriptor mapping to have
5875 // both has-device-addr and ALWAYS, so lookup the mapper in case it
5876 // exists.
5877 if (mapOp.getMapperId()) {
5878 mapData.Mappers.push_back(
5880 mapOp, mapOp.getMapperIdAttr()));
5881 } else {
5882 mapData.Mappers.push_back(nullptr);
5883 }
5884 } else {
5885 // For is_device_ptr we need the map type to propagate so the runtime
5886 // can materialize the device-side copy of the pointer container.
5887 mapData.Types.push_back(
5888 isDevicePtr ? mapType
5889 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5890 mapData.Mappers.push_back(nullptr);
5891 }
5892 mapData.Names.push_back(LLVM::createMappingInformation(
5893 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5894 mapData.DevicePointers.push_back(
5895 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5896 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5897 mapData.IsAMapping.push_back(false);
5898 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5899 }
5900}
5901
5902static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
5903 auto *res = llvm::find(mapData.MapClause, memberOp);
5904 assert(res != mapData.MapClause.end() &&
5905 "MapInfoOp for member not found in MapData, cannot return index");
5906 return std::distance(mapData.MapClause.begin(), res);
5907}
5908
5910 omp::MapInfoOp mapInfo, bool first = true) {
5911 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5912 llvm::SmallVector<size_t> occludedChildren;
5913 llvm::sort(
5914 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
5915 // Bail early if we are asked to look at the same index. If we do not
5916 // bail early, we can end up mistakenly adding indices to
5917 // occludedChildren. This can occur with some types of libc++ hardening.
5918 if (a == b)
5919 return false;
5920
5921 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5922 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5923
5924 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5925 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5926 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5927
5928 if (aIndex == bIndex)
5929 continue;
5930
5931 if (aIndex < bIndex)
5932 return first;
5933
5934 if (aIndex > bIndex)
5935 return !first;
5936 }
5937
5938 // Iterated up until the end of the smallest member and
5939 // they were found to be equal up to that point, so select
5940 // the member with the lowest index count, so the "parent"
5941 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5942 if (memberAParent)
5943 occludedChildren.push_back(b);
5944 else
5945 occludedChildren.push_back(a);
5946 return memberAParent;
5947 });
5948
5949 for (auto v : occludedChildren)
5950 indices.erase(std::remove(indices.begin(), indices.end(), v),
5951 indices.end());
5952}
5953
5954static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
5955 bool first) {
5956 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5957 // Only 1 member has been mapped, we can return it.
5958 if (indexAttr.size() == 1)
5959 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5960 llvm::SmallVector<size_t> indices(indexAttr.size());
5961 std::iota(indices.begin(), indices.end(), 0);
5962 sortMapIndices(indices, mapInfo, first);
5963 return llvm::cast<omp::MapInfoOp>(
5964 mapInfo.getMembers()[indices.front()].getDefiningOp());
5965}
5966
5967/// This function calculates the array/pointer offset for map data provided
5968/// with bounds operations, e.g. when provided something like the following:
5969///
5970/// Fortran
5971/// map(tofrom: array(2:5, 3:2))
5972///
5973/// We must calculate the initial pointer offset to pass across, this function
5974/// performs this using bounds.
5975///
5976/// TODO/WARNING: This only supports Fortran's column major indexing currently
5977/// as is noted in the note below and comments in the function, we must extend
5978/// this function when we add a C++ frontend.
/// NOTE: while the bounds are specified in row-major order, they currently
/// need to be flipped for Fortran's column-major array allocation and access
/// (as opposed to C++'s row-major order), hence the backwards processing where
/// order is important. This is likely important to keep in mind for the future
/// when we incorporate a C++ frontend: both frontends will need to agree on the
/// ordering of the generated bounds operations (one may have to flip them) to
/// make the lowering below frontend agnostic. The offload size calculation may
/// also have to be adjusted for C++.
5987static std::vector<llvm::Value *>
5989 llvm::IRBuilderBase &builder, bool isArrayTy,
5990 OperandRange bounds) {
5991 std::vector<llvm::Value *> idx;
5992 // There's no bounds to calculate an offset from, we can safely
5993 // ignore and return no indices.
5994 if (bounds.empty())
5995 return idx;
5996
5997 // If we have an array type, then we have its type so can treat it as a
5998 // normal GEP instruction where the bounds operations are simply indexes
5999 // into the array. We currently do reverse order of the bounds, which
6000 // I believe leans more towards Fortran's column-major in memory.
6001 if (isArrayTy) {
6002 idx.push_back(builder.getInt64(0));
6003 for (int i = bounds.size() - 1; i >= 0; --i) {
6004 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6005 bounds[i].getDefiningOp())) {
6006 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
6007 }
6008 }
6009 } else {
6010 // If we do not have an array type, but we have bounds, then we're dealing
6011 // with a pointer that's being treated like an array and we have the
6012 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
6013 // address (pointer pointing to the actual data) so we must caclulate the
6014 // offset using a single index which the following loop attempts to
6015 // compute using the standard column-major algorithm e.g for a 3D array:
6016 //
6017 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
6018 //
6019 // It is of note that it's doing column-major rather than row-major at the
6020 // moment, but having a way for the frontend to indicate which major format
6021 // to use or standardizing/canonicalizing the order of the bounds to compute
6022 // the offset may be useful in the future when there's other frontends with
6023 // different formats.
6024 std::vector<llvm::Value *> dimensionIndexSizeOffset;
6025 for (int i = bounds.size() - 1; i >= 0; --i) {
6026 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
6027 bounds[i].getDefiningOp())) {
6028 if (i == ((int)bounds.size() - 1))
6029 idx.emplace_back(
6030 moduleTranslation.lookupValue(boundOp.getLowerBound()));
6031 else
6032 idx.back() = builder.CreateAdd(
6033 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
6034 boundOp.getExtent())),
6035 moduleTranslation.lookupValue(boundOp.getLowerBound()));
6036 }
6037 }
6038 }
6039
6040 return idx;
6041}
6042
6044 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
6045 return cast<IntegerAttr>(value).getInt();
6046 });
6047}
6048
6049// Gathers members that are overlapping in the parent, excluding members that
6050// themselves overlap, keeping the top-most (closest to parents level) map.
6051static void
6053 omp::MapInfoOp parentOp) {
6054 // No members mapped, no overlaps.
6055 if (parentOp.getMembers().empty())
6056 return;
6057
6058 // Single member, we can insert and return early.
6059 if (parentOp.getMembers().size() == 1) {
6060 overlapMapDataIdxs.push_back(0);
6061 return;
6062 }
6063
6064 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
6065 size_t numMembers = indexAttr.size();
6066
6067 // Pre-convert all member indices to integer arrays for efficient comparison.
6068 llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices(numMembers);
6069 for (auto [i, indicesAttr] : llvm::enumerate(indexAttr))
6070 getAsIntegers(cast<ArrayAttr>(indicesAttr), memberIndices[i]);
6071
6072 // For each member, check if it's superseded by another (shorter prefix)
6073 // member. If member j's indices are a prefix of member i's indices, then
6074 // i is a child of j and should be skipped. e.g. if member [0] is mapped,
6075 // we skip members [0,1], [0,2], etc.
6076 llvm::SmallDenseSet<size_t> skipIndices;
6077 for (size_t i = 0; i < numMembers; ++i) {
6078 const auto &iIndices = memberIndices[i];
6079 for (size_t j = 0; j < numMembers; ++j) {
6080 if (i == j)
6081 continue;
6082 const auto &jIndices = memberIndices[j];
6083 // If j's indices are a strict prefix of i's indices, skip i
6084 if (jIndices.size() < iIndices.size() &&
6085 std::equal(jIndices.begin(), jIndices.end(), iIndices.begin())) {
6086 skipIndices.insert(i);
6087 break; // No need to check other potential parents
6088 }
6089 }
6090 }
6091
6092 // Collect indices of members that are not superseded by a parent.
6093 for (size_t i = 0; i < numMembers; ++i)
6094 if (!skipIndices.contains(i))
6095 overlapMapDataIdxs.push_back(i);
6096}
6097
6098// The intent is to verify if the mapped data being passed is a
6099// pointer -> pointee that requires special handling in certain cases,
6100// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
6101//
6102// There may be a better way to verify this, but unfortunately with
6103// opaque pointers we lose the ability to easily check if something is
6104// a pointer whilst maintaining access to the underlying type.
6105static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
6106 // If we have a varPtrPtr field assigned then the underlying type is a pointer
6107 if (mapOp.getVarPtrPtr())
6108 return true;
6109
6110 // If the map data is declare target with a link clause, then it's represented
6111 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
6112 // no relation to pointers.
6113 if (isDeclareTargetLink(mapOp.getVarPtr()))
6114 return true;
6115
6116 return false;
6117}
6118
/// This function handles the insertion of a single item of map data from
/// MapInfoData (the entry at `mapDataIdx`) into the OMPIRBuilder's MapInfo
/// list (`combinedInfo`). Utilising this function means the map being inserted
/// can be treated as a non-parent map entity; if `memberOfFlag` is set (i.e.
/// not OMP_MAP_NONE) then the map being inserted is treated as a member map of
/// a larger entity. The insertion can vary based on a number of factors, such
/// as whether it is a ref_ptr or ref_ptee map, whether it is a member of a
/// record, what construct the map belongs to, and the various map type bit
/// flags that are set for the map.
///
/// \param mapDataIdx index of the entry in `mapData` to insert.
/// \param targetDirective construct the map belongs to; TARGET_PARAM is only
///        ever added for TargetDirectiveEnumTy::Target.
/// \param memberOfFlag MEMBER_OF flag of the parent entry, or OMP_MAP_NONE for
///        a standalone map.
/// \param isTargetParam when false, OMP_MAP_TARGET_PARAM is never added.
/// \param mapDataParentIdx index of the parent entry in `mapData`, or -1 when
///        there is no parent; used to choose the emitted base pointer.
static void
processIndividualMap(llvm::IRBuilderBase &builder,
                     llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData,
                     size_t mapDataIdx, MapInfosTy &combinedInfo,
                     TargetDirectiveEnumTy targetDirective,
                     llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
                         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
                     bool isTargetParam = true, int mapDataParentIdx = -1) {
  // Start from the runtime flags already computed for this entry and refine.
  auto mapFlag = mapData.Types[mapDataIdx];
  auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

  // True when a varPtrPtr is present or the variable is declare target link.
  bool isPtrTy = checkIfPointerMap(mapInfoOp);
  bool isAttachMap = ((convertClauseMapFlags(mapInfoOp.getMapType()) &
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
                      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);

  // Declare target variables are not passed to the kernel, and for the moment
  // attach maps are not passed to the kernel. However, it is possible to create
  // attach maps that transfer data and thus can be kernel arguments, but our
  // existing frontend does not do this.
  if (isTargetParam &&
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIdx]) &&
      !isAttachMap)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

  // By-copy captures of non-pointer data are marked LITERAL.
  if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
      !isPtrTy)
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

  // If we have a pointer and it's part of a MEMBER_OF mapping we do not apply
  // MEMBER_OF, as the runtime currently has a work-around that utilises
  // MEMBER_OF to prevent reference updating in certain scenarios instead of
  // target_param. However, this causes a noticeable issue in cases where we
  // map some data (Fortran descriptor primarily at the moment), alter it on
  // the host, and then expect it to not be updated in a subsequent implicit map
  // (such as an implicit map on a target).
  if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) {
    if (!isPtrTy && !isAttachMap)
      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    // The return parameter should be the over-riding parent in cases where we
    // have a return parameter that is echoed to all members, the main case of
    // this currently is with fortran descriptors. It may need more finessing
    // for C/C++ in the future or descriptors that are members of derived
    // types.
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
  }

  // We apply MAP_PTR_AND_OBJ when within a declare mapper object as it enforces
  // MEMBER_OF mappings on maps that are past the initial nesting depth, which
  // includes pointed to data and attach members, both of which are technically
  // not part of the main object. This has the side effect of causing early
  // map-backs in certain cases where an implicit declare mapper has been
  // emitted for a target region. Applying MAP_PTR_AND_OBJ in these situations
  // circumvents this.
  // NOTE(review): the comment above refers to declare *mapper* objects, but
  // the condition below tests mapData.IsDeclareTarget (declare *target*) —
  // confirm which one is intended.
  if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx])
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

  // if we're provided a mapDataParentIdx, then the data being mapped is
  // part of a larger object (in a parent <-> member mapping) and in this
  // case our BasePointer should be the parent. Except in the edge case
  // where we are mapping pointee data, where we try staying close to
  // what Clang currently does and utilise the regular base pointer of the
  // data.
  // isRefPtee: ref_ptee is set without ref_ptr.
  bool isRefPtee =
      !bitEnumContainsAll(mapInfoOp.getMapType(),
                          omp::ClauseMapFlags::ref_ptr) &&
      bitEnumContainsAll(mapInfoOp.getMapType(), omp::ClauseMapFlags::ref_ptee);
  // isRefPtrPtee: both ref_ptr and ref_ptee are set.
  bool isRefPtrPtee = bitEnumContainsAll(mapInfoOp.getMapType(),
                                         omp::ClauseMapFlags::ref_ptr |
                                             omp::ClauseMapFlags::ref_ptee);

  // Maps nested inside an omp.declare_mapper always keep their own base
  // pointer.
  if (!mapInfoOp->getParentOfType<omp::DeclareMapperOp>() &&
      mapDataParentIdx >= 0 && !(isRefPtee || (isRefPtrPtee && isPtrTy))) {
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[mapDataParentIdx]);
  } else {
    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
  }

  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
  // Member entries never carry device-pointer info; only standalone entries
  // propagate the value recorded in mapData.
  combinedInfo.DevicePointers.emplace_back(
      memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE
          ? llvm::OpenMPIRBuilder::DeviceInfoTy::None
          : mapData.DevicePointers[mapDataIdx]);
  combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
  combinedInfo.Types.emplace_back(mapFlag);
  // For pointer-like maps, emit a runtime select so that a null pointer is
  // mapped with size 0 instead of the recorded size.
  combinedInfo.Sizes.emplace_back(
      isPtrTy ? builder.CreateSelect(
                    builder.CreateIsNull(mapData.Pointers[mapDataIdx]),
                    builder.getInt64(0), mapData.Sizes[mapDataIdx])
              : mapData.Sizes[mapDataIdx]);
}
6223
6224// This creates two insertions into the MapInfosTy data structure for the
6225// "parent" of a set of members, (usually a container e.g.
6226// class/structure/derived type) when subsequent members have also been
6227// explicitly mapped on the same map clause. Certain types, such as Fortran
6228// descriptors are mapped like this as well, however, the members are
6229// implicit as far as a user is concerned, but we must explicitly map them
6230// internally.
6231//
6232// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map type
6234// with it) to indicate that a member is part of this parent and should be
6235// treated by the runtime as such. Important to achieve the correct mapping.
6236//
6237// This function borrows a lot from Clang's emitCombinedEntry function
6238// inside of CGOpenMPRuntime.cpp
6240 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
6241 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
6242 MapInfoData &mapData, uint64_t mapDataIndex,
6243 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
6244 TargetDirectiveEnumTy targetDirective) {
6245 using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
6246 assert(!ompBuilder.Config.isTargetDevice() &&
6247 "function only supported for host device codegen");
6248 auto parentClause =
6249 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6250 auto *parentMapper = mapData.Mappers[mapDataIndex];
6251
6252 // Map the first segment of the parent. If a user-defined mapper is attached,
6253 // include the parent's to/from-style bits (and common modifiers) in this
6254 // base entry so the mapper receives correct copy semantics via its 'type'
6255 // parameter. Also keep TARGET_PARAM when required for kernel arguments.
6256 MapFlags baseFlag = (targetDirective == TargetDirectiveEnumTy::Target &&
6257 !mapData.IsDeclareTarget[mapDataIndex])
6258 ? MapFlags::OMP_MAP_TARGET_PARAM
6259 : MapFlags::OMP_MAP_NONE;
6260
6261 if (parentMapper) {
6262 // Preserve relevant map-type bits from the parent clause. These include
6263 // the copy direction (TO/FROM), as well as commonly used modifiers that
6264 // should be visible to the mapper for correct behaviour.
6265 MapFlags parentFlags = mapData.Types[mapDataIndex];
6266 MapFlags preserve = MapFlags::OMP_MAP_TO | MapFlags::OMP_MAP_FROM |
6267 MapFlags::OMP_MAP_ALWAYS | MapFlags::OMP_MAP_CLOSE |
6268 MapFlags::OMP_MAP_PRESENT |
6269 MapFlags::OMP_MAP_OMPX_HOLD |
6270 MapFlags::OMP_MAP_IMPLICIT;
6271 baseFlag |= (parentFlags & preserve);
6272 } else {
6273 MapFlags parentFlags = mapData.Types[mapDataIndex];
6274 MapFlags preserve =
6275 MapFlags::OMP_MAP_PRESENT | MapFlags::OMP_MAP_RETURN_PARAM;
6276 baseFlag |= (parentFlags & preserve);
6277 }
6278
6279 combinedInfo.Types.emplace_back(baseFlag);
6280 combinedInfo.DevicePointers.emplace_back(
6281 mapData.DevicePointers[mapDataIndex]);
6282 // Only attach the mapper to the base entry when we are mapping the whole
6283 // parent. Combined/segment entries must not carry a mapper; otherwise the
6284 // mapper can be invoked with a partial size, which is undefined behaviour.
6285 combinedInfo.Mappers.emplace_back(
6286 parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
6287 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6288 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6289 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
6290
6291 // Calculate size of the parent object being mapped based on the
6292 // addresses at runtime, highAddr - lowAddr = size. This of course
6293 // doesn't factor in allocated data like pointers, hence the further
6294 // processing of members specified by users, or in the case of
6295 // Fortran pointers and allocatables, the mapping of the pointed to
6296 // data by the descriptor (which itself, is a structure containing
6297 // runtime information on the dynamically allocated data).
6298 llvm::Value *lowAddr, *highAddr;
6299 if (!parentClause.getPartialMap()) {
6300 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6301 builder.getPtrTy());
6302 highAddr = builder.CreatePointerCast(
6303 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6304 mapData.Pointers[mapDataIndex], 1),
6305 builder.getPtrTy());
6306 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6307 } else {
6308 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6309 int firstMemberIdx = getMapDataMemberIdx(
6310 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
6311 lowAddr = builder.CreatePointerCast(mapData.BasePointers[firstMemberIdx],
6312 builder.getPtrTy());
6313
6314 int lastMemberIdx = getMapDataMemberIdx(
6315 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
6316 auto lastMemberMapInfo =
6317 cast<omp::MapInfoOp>(mapData.MapClause[lastMemberIdx]);
6318
6319 // NOTE: Currently, for RefPtee the BaseType is set to the varPtrPtr field,
6320 // which is the pointer datas type and not the member within the structure
6321 // that it's part of, so we have to make sure we use the member type in this
6322 // case when calculating the parents size offsets.
6323 // TODO: May be good to extend MapInfoData to support tracking of both
6324 // VarPtr/VarPtrPtr BaseType's to better distinguish what's being used more
6325 // consistently.
6326 bool isRefPteeMap = bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6327 omp::ClauseMapFlags::ref_ptee) &&
6328 !bitEnumContainsAll(lastMemberMapInfo.getMapType(),
6329 omp::ClauseMapFlags::ref_ptr);
6330 llvm::Type *castType = mapData.BaseType[lastMemberIdx];
6331 if (isRefPteeMap)
6332 castType =
6333 moduleTranslation.convertType(lastMemberMapInfo.getVarPtrType());
6334 highAddr = builder.CreatePointerCast(
6335 builder.CreateGEP(castType, mapData.BasePointers[lastMemberIdx],
6336 builder.getInt64(1)),
6337 builder.getPtrTy());
6338 combinedInfo.Pointers.emplace_back(mapData.BasePointers[firstMemberIdx]);
6339 }
6340
6341 llvm::Value *size = builder.CreateIntCast(
6342 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6343 builder.getInt64Ty(),
6344 /*isSigned=*/false);
6345 combinedInfo.Sizes.push_back(size);
6346
6347 // This creates the initial MEMBER_OF mapping that consists of
6348 // the parent/top level container (same as above effectively, except
6349 // with a fixed initial compile time size and separate maptype which
6350 // indicates the true mape type (tofrom etc.). This parent mapping is
6351 // only relevant if the structure in its totality is being mapped,
6352 // otherwise the above suffices.
6353 if (!parentClause.getPartialMap()) {
6354 // TODO: This will need to be expanded to include the whole host of logic
6355 // for the map flags that Clang currently supports (e.g. it should do some
6356 // further case specific flag modifications). For the moment, it handles
6357 // what we support as expected.
6358 MapFlags mapFlag = mapData.Types[mapDataIndex];
6359 bool hasMapClose = (MapFlags(mapFlag) & MapFlags::OMP_MAP_CLOSE) ==
6360 MapFlags::OMP_MAP_CLOSE;
6361 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
6362
6363 llvm::SmallVector<size_t> overlapIdxs;
6364 // Find all of the members that "overlap", i.e. occlude other members that
6365 // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
6366 // [0,2], but not [1,0].
6367 getOverlappedMembers(overlapIdxs, parentClause);
6368
6369 // When we only have one overlap we skip the case that tries to segment the
6370 // mapping as best it can without creating holes, as the calculation is more
6371 // likely to have more overhead than anything we gain from mapping a smaller
6372 // chunk of data. This can be seen in cases where we are mapping Fortran
6373 // descriptors which are a special case of record type mapping.
6374 //
6375 // The cases for close and update are unique edge cases where the segmenting
6376 // does not play well with the runtime currently.
6377 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose ||
6378 overlapIdxs.size() == 1) {
6379 combinedInfo.Types.emplace_back(mapFlag);
6380 combinedInfo.DevicePointers.emplace_back(
6381 mapData.DevicePointers[mapDataIndex]);
6382 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6383 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6384 combinedInfo.BasePointers.emplace_back(
6385 mapData.BasePointers[mapDataIndex]);
6386 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
6387 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
6388 combinedInfo.Mappers.emplace_back(nullptr);
6389 } else {
6390 // We need to make sure the overlapped members are sorted in order of
6391 // lowest address to highest address.
6392 sortMapIndices(overlapIdxs, parentClause);
6393
6394 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
6395 builder.getPtrTy());
6396 highAddr = builder.CreatePointerCast(
6397 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
6398 mapData.Pointers[mapDataIndex], 1),
6399 builder.getPtrTy());
6400
6401 // Currently, the return parameter should be the over-riding parent in
6402 // cases where we have a return parameter that is echoed to all members,
6403 // the main case of this currently is with fortran descriptors. It may
6404 // need more finessing for C/C++ in the future or descriptors that are
6405 // members of derived types.
6406 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
6407
6408 // TODO: We may want to skip arrays/array sections in this as Clang does.
6409 // It appears to be an optimisation rather than a necessity though,
6410 // but this requires further investigation. However, we would have to make
6411 // sure to not exclude maps with bounds that ARE pointers, as these are
6412 // processed as separate components, i.e. pointer + data.
6413 for (auto v : overlapIdxs) {
6414 auto mapDataOverlapIdx = getMapDataMemberIdx(
6415 mapData,
6416 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
6417 auto isPtrMap = checkIfPointerMap(
6418 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataOverlapIdx]));
6419 combinedInfo.Types.emplace_back(mapFlag);
6420 combinedInfo.DevicePointers.emplace_back(
6421 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6422 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6423 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6424 combinedInfo.BasePointers.emplace_back(
6425 mapData.BasePointers[mapDataIndex]);
6426 combinedInfo.Mappers.emplace_back(nullptr);
6427 combinedInfo.Pointers.emplace_back(lowAddr);
6428 auto sizeCalc = builder.CreateIntCast(
6429 builder.CreatePtrDiff(builder.getInt8Ty(),
6430 mapData.OriginalValue[mapDataOverlapIdx],
6431 lowAddr),
6432 builder.getInt64Ty(), /*isSigned=*/true);
6433 // In certain cases, we'll generate a size of 0 if we're not careful
6434 // (e.g. if lowAddr happens to be the first member), which isn't
6435 // correct, even if the runtimes is sometimes fine with it so, in these
6436 // scenarios we select the types size instead.
6437 auto sizeSel = builder.CreateSelect(
6438 builder.CreateICmpNE(builder.getInt64(0), sizeCalc), sizeCalc,
6439 isPtrMap ? llvm::ConstantExpr::getSizeOf(builder.getPtrTy())
6440 : mapData.Sizes[mapDataOverlapIdx]);
6441 combinedInfo.Sizes.emplace_back(sizeSel);
6442 lowAddr = builder.CreateConstGEP1_32(
6443 isPtrMap ? builder.getPtrTy() : mapData.BaseType[mapDataOverlapIdx],
6444 mapData.BasePointers[mapDataOverlapIdx], 1);
6445 }
6446
6447 combinedInfo.Types.emplace_back(mapFlag);
6448 combinedInfo.DevicePointers.emplace_back(
6449 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
6450 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
6451 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
6452 combinedInfo.BasePointers.emplace_back(
6453 mapData.BasePointers[mapDataIndex]);
6454 combinedInfo.Mappers.emplace_back(nullptr);
6455 combinedInfo.Pointers.emplace_back(lowAddr);
6456 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
6457 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
6458 builder.getInt64Ty(), true));
6459 }
6460 }
6461}
6462
6464 llvm::IRBuilderBase &builder,
6465 llvm::OpenMPIRBuilder &ompBuilder,
6466 DataLayout &dl, MapInfosTy &combinedInfo,
6467 MapInfoData &mapData, uint64_t mapDataIndex,
6468 TargetDirectiveEnumTy targetDirective) {
6469 assert(!ompBuilder.Config.isTargetDevice() &&
6470 "function only supported for host device codegen");
6471
6472 auto parentClause =
6473 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6474
6475 // If we have a partial map (no parent referenced in the map clauses of the
6476 // directive, only members) and only a single member, we do not need to bind
6477 // the map of the member to the parent, we can pass the member separately.
6478 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
6479 auto memberClause = llvm::cast<omp::MapInfoOp>(
6480 parentClause.getMembers()[0].getDefiningOp());
6481 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
6482 // Note: Clang treats arrays with explicit bounds that fall into this
6483 // category as a parent with map case, however, it seems this isn't a
6484 // requirement, and processing them as an individual map is fine. So,
6485 // we will handle them as individual maps for the moment, as it's
6486 // difficult for us to check this as we always require bounds to be
6487 // specified currently and it's also marginally more optimal (single
6488 // map rather than two). The difference may come from the fact that
6489 // Clang maps array without bounds as pointers (which we do not
6490 // currently do), whereas we treat them as arrays in all cases
6491 // currently.
6493 builder, ompBuilder, mapData, memberDataIdx, combinedInfo,
6494 targetDirective,
6495 /*MemberOfFlag=*/llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE,
6496 /*isTargetParam=*/true, mapDataIndex);
6497 return;
6498 }
6499
6500 auto collectMapInfoIdxs =
6501 [&](llvm::SmallVectorImpl<int64_t> &mapsAndInfoIdx) {
6502 auto parentClause =
6503 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
6504 mapsAndInfoIdx.push_back(getMapDataMemberIdx(mapData, parentClause));
6505 for (auto member : parentClause.getMembers())
6506 mapsAndInfoIdx.push_back(getMapDataMemberIdx(
6507 mapData, llvm::cast<omp::MapInfoOp>(member.getDefiningOp())));
6508 };
6509
6510 llvm::SmallVector<int64_t> mapInfoIdx;
6511 collectMapInfoIdxs(mapInfoIdx);
6512
6513 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
6514 ompBuilder.getMemberOfFlag(combinedInfo.Types.size());
6515 for (size_t i = 0; i < mapInfoIdx.size(); i++) {
6516 // Index == 0 is the parent map and if it gets here it's an unattachable
6517 // type and should have OMP_MAP_TARGET_PARAM applied and no MEMBER_OF flag.
6518 if (i == 0) {
6519 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
6520 combinedInfo, mapData, mapInfoIdx[i], memberOfFlag,
6521 targetDirective);
6522 } else {
6523 processIndividualMap(builder, ompBuilder, mapData, mapInfoIdx[i],
6524 combinedInfo, targetDirective, memberOfFlag,
6525 /*isTargetParam=*/false, mapDataIndex);
6526 }
6527 }
6528}
6529
6530// This is a variation on Clang's GenerateOpenMPCapturedVars, which
6531// generates different operation (e.g. load/store) combinations for
6532// arguments to the kernel, based on map capture kinds which are then
6533// utilised in the combinedInfo in place of the original Map value.
6534static void
6535createAlteredByCaptureMap(MapInfoData &mapData,
6536 LLVM::ModuleTranslation &moduleTranslation,
6537 llvm::IRBuilderBase &builder) {
6538 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6539 "function only supported for host device codegen");
6540 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6541 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6542 bool isAttachMap =
6543 ((convertClauseMapFlags(mapOp.getMapType()) &
6544 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
6545 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH);
6546
6547 // If it's declare target, skip it, it's handled separately. However, if
6548 // it's declare target, and an attach map, we want to calculate the exact
6549 // address offset so that we attach correctly.
6550 if (!mapData.IsDeclareTarget[i] ||
6551 (mapData.IsDeclareTarget[i] && isAttachMap)) {
6552 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
6553 bool isPtrTy = checkIfPointerMap(mapOp);
6554
6555 // Currently handles array sectioning lowerbound case, but more
6556 // logic may be required in the future. Clang invokes EmitLValue,
6557 // which has specialised logic for special Clang types such as user
6558 // defines, so it is possible we will have to extend this for
6559 // structures or other complex types. As the general idea is that this
6560 // function mimics some of the logic from Clang that we require for
6561 // kernel argument passing from host -> device.
6562 switch (captureKind) {
6563 case omp::VariableCaptureKind::ByRef: {
6564 llvm::Value *newV = mapData.Pointers[i];
6565 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
6566 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
6567 mapOp.getBounds());
6568 if (isPtrTy)
6569 newV = builder.CreateLoad(builder.getPtrTy(), newV);
6570
6571 if (!offsetIdx.empty())
6572 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
6573 "array_offset");
6574 mapData.Pointers[i] = newV;
6575 } break;
6576 case omp::VariableCaptureKind::ByCopy: {
6577 llvm::Type *type = mapData.BaseType[i];
6578 llvm::Value *newV;
6579 if (mapData.Pointers[i]->getType()->isPointerTy())
6580 newV = builder.CreateLoad(type, mapData.Pointers[i]);
6581 else
6582 newV = mapData.Pointers[i];
6583
6584 if (!isPtrTy) {
6585 auto curInsert = builder.saveIP();
6586 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
6587 builder.restoreIP(findAllocInsertPoints(builder, moduleTranslation));
6588 auto *memTempAlloc =
6589 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
6590 builder.SetCurrentDebugLocation(DbgLoc);
6591 builder.restoreIP(curInsert);
6592
6593 builder.CreateStore(newV, memTempAlloc);
6594 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
6595 }
6596
6597 mapData.Pointers[i] = newV;
6598 mapData.BasePointers[i] = newV;
6599 } break;
6600 case omp::VariableCaptureKind::This:
6601 case omp::VariableCaptureKind::VLAType:
6602 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
6603 break;
6604 }
6605 }
6606 }
6607}
6608
6609// Generate all map related information and fill the combinedInfo.
6610static void genMapInfos(llvm::IRBuilderBase &builder,
6611 LLVM::ModuleTranslation &moduleTranslation,
6612 DataLayout &dl, MapInfosTy &combinedInfo,
6613 MapInfoData &mapData,
6614 TargetDirectiveEnumTy targetDirective) {
6615 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6616 "function only supported for host device codegen");
6617 // We wish to modify some of the methods in which arguments are
6618 // passed based on their capture type by the target region, this can
6619 // involve generating new loads and stores, which changes the
6620 // MLIR value to LLVM value mapping, however, we only wish to do this
6621 // locally for the current function/target and also avoid altering
6622 // ModuleTranslation, so we remap the base pointer or pointer stored
6623 // in the map infos corresponding MapInfoData, which is later accessed
6624 // by genMapInfos and createTarget to help generate the kernel and
6625 // kernel arg structure. It primarily becomes relevant in cases like
6626 // bycopy, or byref range'd arrays. In the default case, we simply
6627 // pass thee pointer byref as both basePointer and pointer.
6628 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
6629
6630 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6631
6632 // We operate under the assumption that all vectors that are
6633 // required in MapInfoData are of equal lengths (either filled with
6634 // default constructed data or appropiate information) so we can
6635 // utilise the size from any component of MapInfoData, if we can't
6636 // something is missing from the initial MapInfoData construction.
6637 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6638 if (mapData.IsAMember[i])
6639 continue;
6640
6641 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
6642 if (!mapInfoOp.getMembers().empty()) {
6643 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
6644 combinedInfo, mapData, i, targetDirective);
6645 continue;
6646 }
6647
6648 processIndividualMap(builder, *ompBuilder, mapData, i, combinedInfo,
6649 targetDirective);
6650 }
6651}
6652
6653static llvm::Expected<llvm::Function *>
6654emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
6655 LLVM::ModuleTranslation &moduleTranslation,
6656 llvm::StringRef mapperFuncName,
6657 TargetDirectiveEnumTy targetDirective);
6658
6659static llvm::Expected<llvm::Function *>
6660getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
6661 LLVM::ModuleTranslation &moduleTranslation,
6662 TargetDirectiveEnumTy targetDirective) {
6663 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6664 "function only supported for host device codegen");
6665 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6666 std::string mapperFuncName =
6667 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
6668 {"omp_mapper", declMapperOp.getSymName()});
6669
6670 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
6671 return lookupFunc;
6672
6673 // Recursive types can cause re-entrant mapper emission. The mapper function
6674 // is created by OpenMPIRBuilder before the callbacks run, so it may already
6675 // exist in the LLVM module even though it is not yet registered in the
6676 // ModuleTranslation mapping table. Reuse and register it to break the
6677 // recursion.
6678 if (llvm::Function *existingFunc =
6679 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
6680 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
6681 return existingFunc;
6682 }
6683
6684 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
6685 mapperFuncName, targetDirective);
6686}
6687
6688static llvm::Expected<llvm::Function *>
6689emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
6690 LLVM::ModuleTranslation &moduleTranslation,
6691 llvm::StringRef mapperFuncName,
6692 TargetDirectiveEnumTy targetDirective) {
6693 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6694 "function only supported for host device codegen");
6695 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6696 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
6697 if (failed(checkImplementationStatus(*declMapperInfoOp)))
6698 return llvm::make_error<PreviouslyReportedError>();
6699
6700 DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
6701 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6702 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
6703 SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
6704
6705 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6706
6707 // Fill up the arrays with all the mapped variables.
6708 MapInfosTy combinedInfo;
6709 auto genMapInfoCB =
6710 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6711 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6712 builder.restoreIP(codeGenIP);
6713 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
6714 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
6715 builder.GetInsertBlock());
6716 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
6717 /*ignoreArguments=*/true,
6718 builder)))
6719 return llvm::make_error<PreviouslyReportedError>();
6720 MapInfoData mapData;
6721 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6722 builder);
6723 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6724 targetDirective);
6725
6726 // Drop the mapping that is no longer necessary so that the same region
6727 // can be processed multiple times.
6728 moduleTranslation.forgetMapping(declMapperOp.getRegion());
6729 return combinedInfo;
6730 };
6731
6732 auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
6733 if (!combinedInfo.Mappers[i])
6734 return nullptr;
6735 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6736 moduleTranslation, targetDirective);
6737 };
6738
6739 llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
6740 genMapInfoCB, varType, mapperFuncName, customMapperCB,
6741 /*PreserveMemberOfFlags=*/true);
6742 if (!newFn)
6743 return newFn.takeError();
6744 if ([[maybe_unused]] llvm::Function *mappedFunc =
6745 moduleTranslation.lookupFunction(mapperFuncName)) {
6746 assert(mappedFunc == *newFn &&
6747 "mapper function mapping disagrees with emitted function");
6748 } else {
6749 moduleTranslation.mapFunction(mapperFuncName, *newFn);
6750 }
6751 return *newFn;
6752}
6753
6754static LogicalResult
6755convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
6756 LLVM::ModuleTranslation &moduleTranslation) {
6757 llvm::Value *ifCond = nullptr;
6758 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6759 SmallVector<Value> mapVars;
6760 SmallVector<Value> useDevicePtrVars;
6761 SmallVector<Value> useDeviceAddrVars;
6762 llvm::omp::RuntimeFunction RTLFn;
6763 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
6764 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6765
6766 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6767 llvm::OpenMPIRBuilder::TargetDataInfo info(
6768 /*RequiresDevicePointerInfo=*/true,
6769 /*SeparateBeginEndCalls=*/true);
6770 assert(!ompBuilder->Config.isTargetDevice() &&
6771 "target data/enter/exit/update are host ops");
6772 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6773
6774 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
6775 llvm::Value *v = moduleTranslation.lookupValue(dev);
6776 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
6777 };
6778
6779 LogicalResult result =
6781 .Case([&](omp::TargetDataOp dataOp) {
6782 if (failed(checkImplementationStatus(*dataOp)))
6783 return failure();
6784
6785 if (auto ifVar = dataOp.getIfExpr())
6786 ifCond = moduleTranslation.lookupValue(ifVar);
6787
6788 if (mlir::Value devId = dataOp.getDevice())
6789 deviceID = getDeviceID(devId);
6790
6791 mapVars = dataOp.getMapVars();
6792 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6793 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6794 return success();
6795 })
6796 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6797 if (failed(checkImplementationStatus(*enterDataOp)))
6798 return failure();
6799
6800 if (auto ifVar = enterDataOp.getIfExpr())
6801 ifCond = moduleTranslation.lookupValue(ifVar);
6802
6803 if (mlir::Value devId = enterDataOp.getDevice())
6804 deviceID = getDeviceID(devId);
6805
6806 RTLFn =
6807 enterDataOp.getNowait()
6808 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6809 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6810 mapVars = enterDataOp.getMapVars();
6811 info.HasNoWait = enterDataOp.getNowait();
6812 return success();
6813 })
6814 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6815 if (failed(checkImplementationStatus(*exitDataOp)))
6816 return failure();
6817
6818 if (auto ifVar = exitDataOp.getIfExpr())
6819 ifCond = moduleTranslation.lookupValue(ifVar);
6820
6821 if (mlir::Value devId = exitDataOp.getDevice())
6822 deviceID = getDeviceID(devId);
6823
6824 RTLFn = exitDataOp.getNowait()
6825 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6826 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6827 mapVars = exitDataOp.getMapVars();
6828 info.HasNoWait = exitDataOp.getNowait();
6829 return success();
6830 })
6831 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6832 if (failed(checkImplementationStatus(*updateDataOp)))
6833 return failure();
6834
6835 if (auto ifVar = updateDataOp.getIfExpr())
6836 ifCond = moduleTranslation.lookupValue(ifVar);
6837
6838 if (mlir::Value devId = updateDataOp.getDevice())
6839 deviceID = getDeviceID(devId);
6840
6841 RTLFn =
6842 updateDataOp.getNowait()
6843 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6844 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6845 mapVars = updateDataOp.getMapVars();
6846 info.HasNoWait = updateDataOp.getNowait();
6847 return success();
6848 })
6849 .DefaultUnreachable("unexpected operation");
6850
6851 if (failed(result))
6852 return failure();
6853 // Pretend we have IF(false) if we're not doing offload.
6854 if (!isOffloadEntry)
6855 ifCond = builder.getFalse();
6856
6857 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6858 MapInfoData mapData;
6859 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
6860 builder, useDevicePtrVars, useDeviceAddrVars);
6861
6862 // Fill up the arrays with all the mapped variables.
6863 MapInfosTy combinedInfo;
6864 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6865 builder.restoreIP(codeGenIP);
6866 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6867 targetDirective);
6868 return combinedInfo;
6869 };
6870
6871 // Define a lambda to apply mappings between use_device_addr and
6872 // use_device_ptr base pointers, and their associated block arguments.
6873 auto mapUseDevice =
6874 [&moduleTranslation](
6875 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6877 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
6878 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
6879 for (auto [arg, useDevVar] :
6880 llvm::zip_equal(blockArgs, useDeviceVars)) {
6881
6882 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6883 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6884 : mapInfoOp.getVarPtr();
6885 };
6886
6887 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6888 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6889 mapInfoData.MapClause, mapInfoData.DevicePointers,
6890 mapInfoData.BasePointers)) {
6891 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6892 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6893 devicePointer != type)
6894 continue;
6895
6896 if (llvm::Value *devPtrInfoMap =
6897 mapper ? mapper(basePointer) : basePointer) {
6898 moduleTranslation.mapValue(arg, devPtrInfoMap);
6899 break;
6900 }
6901 }
6902 }
6903 };
6904
6905 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6906 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6907 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6908 // We must always restoreIP regardless of doing anything the caller
6909 // does not restore it, leading to incorrect (no) branch generation.
6910 builder.restoreIP(codeGenIP);
6911 assert(isa<omp::TargetDataOp>(op) &&
6912 "BodyGen requested for non TargetDataOp");
6913 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6914 Region &region = cast<omp::TargetDataOp>(op).getRegion();
6915 switch (bodyGenType) {
6916 case BodyGenTy::Priv:
6917 // Check if any device ptr/addr info is available
6918 if (!info.DevicePtrInfoMap.empty()) {
6919 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6920 blockArgIface.getUseDeviceAddrBlockArgs(),
6921 useDeviceAddrVars, mapData,
6922 [&](llvm::Value *basePointer) -> llvm::Value * {
6923 if (!info.DevicePtrInfoMap[basePointer].second)
6924 return nullptr;
6925 return builder.CreateLoad(
6926 builder.getPtrTy(),
6927 info.DevicePtrInfoMap[basePointer].second);
6928 });
6929 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6930 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6931 mapData, [&](llvm::Value *basePointer) {
6932 return info.DevicePtrInfoMap[basePointer].second;
6933 });
6934
6935 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6936 moduleTranslation)))
6937 return llvm::make_error<PreviouslyReportedError>();
6938 }
6939 break;
6940 case BodyGenTy::DupNoPriv:
6941 if (info.DevicePtrInfoMap.empty()) {
6942 // For host device we still need to do the mapping for codegen,
6943 // otherwise it may try to lookup a missing value.
6944 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6945 blockArgIface.getUseDeviceAddrBlockArgs(),
6946 useDeviceAddrVars, mapData);
6947 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6948 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6949 mapData);
6950 }
6951 break;
6952 case BodyGenTy::NoPriv:
6953 // If device info is available then region has already been generated
6954 if (info.DevicePtrInfoMap.empty()) {
6955 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6956 moduleTranslation)))
6957 return llvm::make_error<PreviouslyReportedError>();
6958 }
6959 break;
6960 }
6961 return builder.saveIP();
6962 };
6963
6964 auto customMapperCB =
6965 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6966 if (!combinedInfo.Mappers[i])
6967 return nullptr;
6968 info.HasMapper = true;
6969 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6970 moduleTranslation, targetDirective);
6971 };
6972
6973 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6975 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6976 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
6977 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6978 if (isa<omp::TargetDataOp>(op))
6979 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6980 deallocBlocks, deviceID, ifCond, info,
6981 genMapInfoCB, customMapperCB,
6982 /*MapperFunc=*/nullptr, bodyGenCB,
6983 /*DeviceAddrCB=*/nullptr);
6984 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6985 deallocBlocks, deviceID, ifCond, info,
6986 genMapInfoCB, customMapperCB, &RTLFn);
6987 }();
6988
6989 if (failed(handleError(afterIP, *op)))
6990 return failure();
6991
6992 builder.restoreIP(*afterIP);
6993 return success();
6994}
6995
/// Translate an `omp.distribute` operation to LLVM IR through
/// OpenMPIRBuilder::createDistribute.
///
/// If the parent `omp.teams` op's reduction is captured by this specific
/// distribute op (getDistributeCapturingTeamsReduction), the teams reduction
/// declarations and private reduction variables are set up before the
/// distribute construct is created and handed to the reduction code emitted
/// after it. Unless the nested wrapper is an `omp.wsloop`, the body's loop is
/// workshared with a DistributeStaticLoop schedule.
6996static LogicalResult
6997convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
6998 LLVM::ModuleTranslation &moduleTranslation) {
6999 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7000 auto distributeOp = cast<omp::DistributeOp>(opInst);
7001 if (failed(checkImplementationStatus(opInst)))
7002 return failure();
7003
7004 /// Process teams op reduction in distribute if the reduction is contained in
7005 /// this specific distribute op.
7006 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
7007 bool doDistributeReduction =
7008 teamsOp && getDistributeCapturingTeamsReduction(teamsOp) == distributeOp;
7009
7010 DenseMap<Value, llvm::Value *> reductionVariableMap;
7011 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
7013 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
7014 llvm::ArrayRef<bool> isByRef;
7015
7016 if (doDistributeReduction) {
7017 isByRef = getIsByRef(teamsOp.getReductionByref());
7018 assert(isByRef.size() == teamsOp.getNumReductionVars());
7019
7020 collectReductionDecls(teamsOp, reductionDecls);
7021 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7022 findAllocInsertPoints(builder, moduleTranslation);
7023
7024 MutableArrayRef<BlockArgument> reductionArgs =
7025 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
7026 .getReductionBlockArgs();
7027
7029 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
7030 reductionDecls, privateReductionVariables, reductionVariableMap,
7031 isByRef)))
7032 return failure();
7033 }
7034
7035 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 // Body generator invoked by createDistribute: sets up private variables,
 // translates the region, optionally workshares the loop, then cleans up
 // the private variables.
7036 auto bodyGenCB =
7037 [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7038 llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
7039 // Save the alloca insertion point on ModuleTranslation stack for use in
7040 // nested regions.
7042 moduleTranslation, allocaIP, deallocBlocks);
7043
7044 // DistributeOp has only one region associated with it.
7045 builder.restoreIP(codeGenIP);
7046 PrivateVarsInfo privVarsInfo(distributeOp);
7047
7049 distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
7050 if (handleError(afterAllocas, opInst).failed())
7051 return llvm::make_error<PreviouslyReportedError>();
7052
7053 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
7054 opInst)
7055 .failed())
7056 return llvm::make_error<PreviouslyReportedError>();
7057
7058 if (failed(copyFirstPrivateVars(
7059 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
7060 privVarsInfo.llvmVars, privVarsInfo.privatizers,
7061 distributeOp.getPrivateNeedsBarrier())))
7062 return llvm::make_error<PreviouslyReportedError>();
7063
7064 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7065 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7067 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
7068 builder, moduleTranslation);
7069 if (!regionBlock)
7070 return regionBlock.takeError();
7071 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
7072
7073 // Skip applying a workshare loop below when translating 'distribute
7074 // parallel do' (it's been already handled by this point while translating
7075 // the nested omp.wsloop).
7076 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
7077 // TODO: Add support for clauses which are valid for DISTRIBUTE
7078 // constructs. Static schedule is the default.
7079 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
7080 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
7081 : omp::ClauseScheduleKind::Static;
7082 // dist_schedule clauses are ordered - otherwise this should be false
7083 bool isOrdered = hasDistSchedule;
7084 std::optional<omp::ScheduleModifier> scheduleMod;
7085 bool isSimd = false;
7086 llvm::omp::WorksharingLoopType workshareLoopType =
7087 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
7088 bool loopNeedsBarrier = false;
7089 llvm::Value *chunk = moduleTranslation.lookupValue(
7090 distributeOp.getDistScheduleChunkSize());
7091 llvm::CanonicalLoopInfo *loopInfo =
7092 findCurrentLoopInfo(moduleTranslation);
 // NOTE(review): `chunk` is passed twice below: as the schedule chunk size
 // and again as the last argument (right after `hasDistSchedule`).
 // `scheduleMod` is never set here, so both modifier flags are false.
7093 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
7094 ompBuilder->applyWorkshareLoop(
7095 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
7096 convertToScheduleKind(schedule), chunk, isSimd,
7097 scheduleMod == omp::ScheduleModifier::monotonic,
7098 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
7099 workshareLoopType, false, hasDistSchedule, chunk);
7100
7101 if (!wsloopIP)
7102 return wsloopIP.takeError();
7103 }
7104 if (failed(cleanupPrivateVars(distributeOp, builder, moduleTranslation,
7105 distributeOp.getLoc(), privVarsInfo)))
7106 return llvm::make_error<PreviouslyReportedError>();
7107
7108 return llvm::Error::success();
7109 };
7110
7112 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7113 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
7114 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7115 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7116 ompBuilder->createDistribute(ompLoc, allocaIP, deallocBlocks, bodyGenCB);
7117
7118 if (failed(handleError(afterIP, opInst)))
7119 return failure();
7120
7121 builder.restoreIP(*afterIP);
7122
7123 if (doDistributeReduction) {
7124 // Process the reductions if required.
7126 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
7127 privateReductionVariables, isByRef,
7128 /*isNoWait*/ false, /*isTeamsReduction*/ true);
7129 }
7130 return success();
7131}
7132
7133/// Lowers the FlagsAttr which is applied to the module on the device
7134/// pass when offloading, this attribute contains OpenMP RTL globals that can
7135/// be passed as flags to the frontend, otherwise they are set to default
7136static LogicalResult
7137convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
7138 LLVM::ModuleTranslation &moduleTranslation) {
7139 if (!cast<mlir::ModuleOp>(op))
7140 return failure();
7141
7142 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7143
7144 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
7145 attribute.getOpenmpDeviceVersion());
7146
7147 if (attribute.getNoGpuLib())
7148 return success();
7149
7150 ompBuilder->createGlobalFlag(
7151 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
7152 "__omp_rtl_debug_kind");
7153 ompBuilder->createGlobalFlag(
7154 attribute
7155 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
7156 ,
7157 "__omp_rtl_assume_teams_oversubscription");
7158 ompBuilder->createGlobalFlag(
7159 attribute
7160 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
7161 ,
7162 "__omp_rtl_assume_threads_oversubscription");
7163 ompBuilder->createGlobalFlag(
7164 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
7165 "__omp_rtl_assume_no_thread_state");
7166 ompBuilder->createGlobalFlag(
7167 attribute
7168 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
7169 ,
7170 "__omp_rtl_assume_no_nested_parallelism");
7171 return success();
7172}
7173
7174static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
7175 omp::TargetOp targetOp,
7176 llvm::OpenMPIRBuilder &ompBuilder,
7177 llvm::vfs::FileSystem &vfs,
7178 llvm::StringRef parentName = "") {
7179 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
7180 assert(fileLoc && "No file found from location");
7181
7182 auto fileInfoCallBack = [&fileLoc]() {
7183 return std::pair<std::string, uint64_t>(
7184 llvm::StringRef(fileLoc.getFilename()), fileLoc.getLine());
7185 };
7186
7187 targetInfo =
7188 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, vfs, parentName);
7189}
7190
/// For every declare-target entry in \p mapData, rewrite the uses of its
/// original value (mapData.OriginalValue[i]) inside \p func so they refer to
/// the corresponding base pointer instead. For `link` variables, and for `to`
/// variables when unified shared memory is required, each rewritten user gets
/// its own load from the base pointer, placed right before that user.
/// Only valid when compiling for the target device (asserted). The builder's
/// insertion point is restored on return (InsertPointGuard).
7191static void
7192handleDeclareTargetMapVar(MapInfoData &mapData,
7193 LLVM::ModuleTranslation &moduleTranslation,
7194 llvm::IRBuilderBase &builder, llvm::Function *func) {
7195 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
7196 "function only supported for target device codegen");
7197 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7198 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
7199 // In the case of declare target mapped variables, the basePointer is
7200 // the reference pointer generated by the convertDeclareTargetAttr
7201 // method. Whereas the kernelValue is the original variable, so for
7202 // the device we must replace all uses of this original global variable
7203 // (stored in kernelValue) with the reference pointer (stored in
7204 // basePointer for declare target mapped variables), as for device the
7205 // data is mapped into this reference pointer and should be loaded
7206 // from it, the original variable is discarded. On host both exist and
7207 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
7208 // function to link the two variables in the runtime and then both the
7209 // reference pointer and the pointer are assigned in the kernel argument
7210 // structure for the host.
7211 if (!mapData.IsDeclareTarget[i])
7212 continue;
7213 // If the original map value is a constant, then we have to make sure all
7214 // of its uses within the current kernel/function that we are going to
7215 // rewrite are converted to instructions, as we will be altering the old
7216 // use (OriginalValue) from a constant to an instruction, which will be
7217 // illegal and ICE the compiler if the user is a constant expression of
7218 // some kind e.g. a constant GEP.
7219 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
7220 convertUsersOfConstantsToInstructions(constant, func, false);
7221
7222 // The users iterator will get invalidated if we modify an element,
7223 // so we populate this vector of uses to alter each user on an
7224 // individual basis to emit its own load (rather than one load for
7225 // all).
7227 for (llvm::User *user : mapData.OriginalValue[i]->users())
7228 userVec.push_back(user);
7229
7230 for (llvm::User *user : userVec) {
 // Only instruction users that live in the function being rewritten are
 // updated; other users keep referring to the original value.
7231 auto *insn = dyn_cast<llvm::Instruction>(user);
7232 if (!insn || insn->getFunction() != func)
7233 continue;
7234 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7235 llvm::Value *substitute = mapData.BasePointers[i];
7236 auto declTarPtr =
7237 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
7238 if (isDeclareTargetLink(declTarPtr) ||
7239 (isDeclareTargetTo(declTarPtr) &&
7240 moduleTranslation.getOpenMPBuilder()
7241 ->Config.hasRequiresUnifiedSharedMemory())) {
7242 builder.SetCurrentDebugLocation(insn->getDebugLoc());
7243 substitute = builder.CreateLoad(mapData.BasePointers[i]->getType(),
7244 mapData.BasePointers[i]);
7245 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
7246 }
7247 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
7248 }
7249 }
7250}
7251
7252// The createDeviceArgumentAccessor function generates
7253// instructions for retrieving (accessing) kernel
7254// arguments inside of the device kernel for use by
7255// the kernel. This enables different semantics such as
7256// the creation of temporary copies of data allowing
7257// semantics like read-only/no host write back kernel
7258// arguments.
7259//
7260// This currently implements a very light version of Clang's
7261// EmitParmDecl's handling of direct argument handling as well
7262// as a portion of the argument access generation based on
7263// capture types found at the end of emitOutlinedFunctionPrologue
7264// in Clang. The indirect path handling of EmitParmDecl's may be
7265// required for future work, but a direct 1-to-1 copy doesn't seem
7266// possible as the logic is rather scattered throughout Clang's
7267// lowering and perhaps we wish to deviate slightly.
7268//
7269// \param mapData - A container containing vectors of information
7270// corresponding to the input argument, which should have a
7271// corresponding entry in the MapInfoData containers'
7272// OriginalValue vectors.
7273// \param arg - This is the generated kernel function argument that
7274// corresponds to the passed in input argument. We generate different
7275// accesses of this Argument, based on capture type and other Input
7276// related information.
7277// \param input - This is the host side value that will be passed to
7278// the kernel i.e. the kernel input, we rewrite all uses of this within
7279// the kernel (as we generate the kernel body based on the target's region
7280// which maintains references to the original input) to the retVal argument
7281// upon exit of this function inside of the OMPIRBuilder. This interlinks
7282// the kernel argument to future uses of it in the function providing
7283// appropriate "glue" instructions in between.
7284// \param retVal - This is the value that all uses of input inside of the
7285// kernel will be re-written to, the goal of this function is to generate
7286// an appropriate location for the kernel argument to be accessed from,
7287// e.g. ByRef will result in a temporary allocation location and then
7288// a store of the kernel argument into this allocated memory which
7289// will then be loaded from, ByCopy will use the allocated memory
7290// directly.
//
// Returns the builder's insertion point after the access code has been
// emitted at codeGenIP.
7291static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
7292 omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
7293 llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
7294 llvm::OpenMPIRBuilder &ompBuilder,
7295 LLVM::ModuleTranslation &moduleTranslation,
7296 llvm::IRBuilderBase::InsertPoint allocaIP,
7297 llvm::IRBuilderBase::InsertPoint codeGenIP,
7299 assert(ompBuilder.Config.isTargetDevice() &&
7300 "function only supported for target device codegen");
7301 builder.restoreIP(allocaIP);
7302
 // Defaults used when no map clause matches `input`: ByRef capture and no
 // known alignment (0 suppresses the !align metadata below).
7303 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
7304 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
7305 ompBuilder.M.getContext());
7306 unsigned alignmentValue = 0;
7307 BlockArgument mlirArg;
7309 cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
7310 blockArgsPairs);
7311 // Find the associated MapInfoData entry for the current input
7312 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
7313 if (mapData.OriginalValue[i] == input) {
7314 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
7315 capture = mapOp.getMapCaptureType();
7316 // Get information of alignment of mapped object
7317 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
7318 mapOp.getVarPtrType(), ompBuilder.M.getDataLayout());
7319
7320 // Find the corresponding entry block argument, which can be associated to
7321 // a map, use_device* or has_device* clause.
7322 for (auto &[val, arg] : blockArgsPairs) {
7323 if (mapOp.getResult() == val) {
7324 mlirArg = arg;
7325 break;
7326 }
7327 }
7328 assert(mlirArg && "expected to find entry block argument for map clause");
7329 break;
7330 }
7331 }
7332
7333 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
7334 unsigned int defaultAS =
7335 ompBuilder.M.getDataLayout().getProgramAddressSpace();
7336
7337 // Create the allocation for the argument.
7338 llvm::Value *v = nullptr;
7339 if (omp::opInSharedDeviceContext(*targetOp) &&
7341 // Use the beginning of the codeGenIP rather than the usual allocation point
7342 // for shared memory allocations because otherwise these would be done prior
7343 // to the target initialization call. Also, the exit block (where the
7344 // deallocation is placed) is only executed if the initialization call
7345 // succeeds.
7346 builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
7347 v = ompBuilder.createOMPAllocShared(builder, arg.getType());
7348
7349 // Create deallocations in all provided deallocation points and then restore
7350 // the insertion point to right after the new allocations.
7351 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7352 for (auto deallocIP : deallocIPs) {
7353 builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
7354 ompBuilder.createOMPFreeShared(builder, v, arg.getType());
7355 }
7356 } else {
7357 // Use the current point, which was previously set to allocaIP.
7358 v = builder.CreateAlloca(arg.getType(), allocaAS);
7359
7360 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
7361 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
7362 }
7363
7364 builder.CreateStore(&arg, v);
7365
7366 builder.restoreIP(codeGenIP);
7367
7368 switch (capture) {
7369 case omp::VariableCaptureKind::ByCopy: {
7370 retVal = v;
7371 break;
7372 }
7373 case omp::VariableCaptureKind::ByRef: {
7374 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
7375 v->getType(), v,
7376 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
7377 // CreateAlignedLoad function creates similar LLVM IR:
7378 // %res = load ptr, ptr %input, align 8
7379 // This LLVM IR does not contain information about alignment
7380 // of the loaded value. We need to add !align metadata to unblock
7381 // optimizer. The existence of the !align metadata on the instruction
7382 // tells the optimizer that the value loaded is known to be aligned to
7383 // a boundary specified by the integer value in the metadata node.
7384 // Example:
7385 // %res = load ptr, ptr %input, align 8, !align !align_md_node
7386 // ^ ^
7387 // | |
7388 // alignment of %input address |
7389 // |
7390 // alignment of %res object
7391 if (v->getType()->isPointerTy() && alignmentValue) {
7392 llvm::MDBuilder MDB(builder.getContext());
7393 loadInst->setMetadata(
7394 llvm::LLVMContext::MD_align,
7395 llvm::MDNode::get(builder.getContext(),
7396 MDB.createConstant(llvm::ConstantInt::get(
7397 llvm::Type::getInt64Ty(builder.getContext()),
7398 alignmentValue))));
7399 }
7400 retVal = loadInst;
7401
7402 break;
7403 }
7404 case omp::VariableCaptureKind::This:
7405 case omp::VariableCaptureKind::VLAType:
7406 // TODO: Consider returning error to use standard reporting for
7407 // unimplemented features.
 // NOTE(review): in builds without assertions this path falls through
 // and leaves `retVal` unmodified.
7408 assert(false && "Currently unsupported capture kind");
7409 break;
7410 }
7411
7412 return builder.saveIP();
7413}
7414
7415/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
7416/// operation and populate output variables with their corresponding host value
7417/// (i.e. operand evaluated outside of the target region), based on their uses
7418/// inside of the target region.
7419///
7420/// Loop bounds and steps are only optionally populated, if output vectors are
7421/// provided.
///
/// Recognized uses are: `omp.teams` num_teams lower/upper bound and the first
/// thread_limit operand, the first `omp.parallel` num_threads operand, and
/// `omp.loop_nest` lower bounds, upper bounds and steps. Any other use of a
/// host_eval block argument is treated as unreachable. When provided, the loop
/// output vectors are indexed directly (not resized), so they must already be
/// large enough for every loop index that is written.
7422static void
7423extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
7424 Value &numTeamsLower, Value &numTeamsUpper,
7425 Value &threadLimit,
7426 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
7427 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
7428 llvm::SmallVectorImpl<Value> *steps = nullptr) {
7429 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
7430 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
7431 blockArgIface.getHostEvalBlockArgs())) {
7432 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
7433
7434 for (Operation *user : blockArg.getUsers()) {
7436 .Case([&](omp::TeamsOp teamsOp) {
7437 if (teamsOp.getNumTeamsLower() == blockArg)
7438 numTeamsLower = hostEvalVar;
7439 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
7440 blockArg))
7441 numTeamsUpper = hostEvalVar;
7442 else if (!teamsOp.getThreadLimitVars().empty() &&
7443 teamsOp.getThreadLimit(0) == blockArg)
7444 threadLimit = hostEvalVar;
7445 else
7446 llvm_unreachable("unsupported host_eval use");
7447 })
7448 .Case([&](omp::ParallelOp parallelOp) {
7449 if (!parallelOp.getNumThreadsVars().empty() &&
7450 parallelOp.getNumThreads(0) == blockArg)
7451 numThreads = hostEvalVar;
7452 else
7453 llvm_unreachable("unsupported host_eval use");
7454 })
7455 .Case([&](omp::LoopNestOp loopOp) {
 // Record hostEvalVar at every position where blockArg appears in
 // opBounds (if an output vector was given); report whether any
 // position matched.
7456 auto processBounds =
7457 [&](OperandRange opBounds,
7458 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
7459 bool found = false;
7460 for (auto [i, lb] : llvm::enumerate(opBounds)) {
7461 if (lb == blockArg) {
7462 found = true;
7463 if (outBounds)
7464 (*outBounds)[i] = hostEvalVar;
7465 }
7466 }
7467 return found;
7468 };
7469 bool found =
7470 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
7471 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
7472 found;
7473 found = processBounds(loopOp.getLoopSteps(), steps) || found;
7474 (void)found;
7475 assert(found && "unsupported host_eval use");
7476 })
7477 .DefaultUnreachable("unsupported host_eval use");
7478 }
7479 }
7480}
7481
7482/// If \p op is of the given type parameter, return it casted to that type.
7483/// Otherwise, if its immediate parent operation (or some other higher-level
7484/// parent, if \p immediateParent is false) is of that type, return that parent
7485/// casted to the given type.
7486///
7487/// If \p op is \c null or neither it or its parent(s) are of the specified
7488/// type, return a \c null operation.
7489template <typename OpTy>
7490static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
7491 if (!op)
7492 return OpTy();
7493
7494 if (OpTy casted = dyn_cast<OpTy>(op))
7495 return casted;
7496
7497 if (immediateParent)
7498 return dyn_cast_if_present<OpTy>(op->getParentOp());
7499
7500 return op->getParentOfType<OpTy>();
7501}
7502
7503/// If the given \p value is defined by an \c llvm.mlir.constant operation and
7504/// it is of an integer type, return its value.
7505static std::optional<int64_t> extractConstInteger(Value value) {
7506 if (!value)
7507 return std::nullopt;
7508
7509 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
7510 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
7511 return constAttr.getInt();
7512
7513 return std::nullopt;
7514}
7515
7516static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
7517 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
7518 uint64_t sizeInBytes = sizeInBits / 8;
7519 return sizeInBytes;
7520}
7521
/// Return the size in bytes, under the data layout of the enclosing module, of
/// a non-packed literal LLVM struct whose members are the types of \p op's
/// reduction declarations (for by-ref reductions, the by-ref element type).
/// Returns 0 when \p op has no reduction variables.
7522template <typename OpTy>
7523static uint64_t getReductionDataSize(OpTy &op) {
7524 if (op.getNumReductionVars() > 0) {
7526 collectReductionDecls(op, reductions);
7527
7529 members.reserve(reductions.size());
7530 for (omp::DeclareReductionOp &red : reductions) {
7531 // For by-ref reductions, use the actual element type rather than the
7532 // pointer type so that the buffer size matches the access pattern in
7533 // the copy/reduce callbacks generated by OMPIRBuilder.
7534 if (red.getByrefElementType())
7535 members.push_back(*red.getByrefElementType());
7536 else
7537 members.push_back(red.getType());
7538 }
7539 Operation *opp = op.getOperation();
7540 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
7541 opp->getContext(), members, /*isPacked=*/false);
7542 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
7543 return getTypeByteSize(structType, dl);
7544 }
7545 return 0;
7546}
7547
7548/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
7549/// values as stated by the corresponding clauses, if constant.
///
/// Also sets `ExecFlags` (from the target op's kernel exec mode),
/// `MinThreads` (always 1), `ReductionDataSize` (teams reductions, GPU only)
/// and, when that size is nonzero, `ReductionBufferLength`.
7550///
7551/// These default values must be set before the creation of the outlined LLVM
7552/// function for the target region, so that they can be used to initialize the
7553/// corresponding global `ConfigurationEnvironmentTy` structure.
7554static void
7555initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
7556 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
7557 bool isTargetDevice, bool isGPU) {
7558 // TODO: Handle constant 'if' clauses.
7559
7560 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
7561 if (!isTargetDevice) {
7562 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
7563 threadLimit);
7564 } else {
7565 // In the target device, values for these clauses are not passed as
7566 // host_eval, but instead evaluated prior to entry to the region. This
7567 // ensures values are mapped and available inside of the target region.
7568 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7569 numTeamsLower = teamsOp.getNumTeamsLower();
7570 // Handle num_teams upper bounds (only first value for now)
7571 if (!teamsOp.getNumTeamsUpperVars().empty())
7572 numTeamsUpper = teamsOp.getNumTeams(0);
7573 if (!teamsOp.getThreadLimitVars().empty())
7574 threadLimit = teamsOp.getThreadLimit(0);
7575 }
7576
7577 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
7578 if (!parallelOp.getNumThreadsVars().empty())
7579 numThreads = parallelOp.getNumThreads(0);
7580 }
7581 }
7582
7583 // Handle clauses impacting the number of teams.
7584
7585 int32_t minTeamsVal = 1, maxTeamsVal = -1;
7586 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
7587 // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now,
7588 // match clang and set min and max to the same value.
7589 if (numTeamsUpper) {
 // NOTE(review): *val is int64_t and is implicitly narrowed to int32_t.
7590 if (auto val = extractConstInteger(numTeamsUpper))
7591 minTeamsVal = maxTeamsVal = *val;
7592 } else {
7593 minTeamsVal = maxTeamsVal = 0;
7594 }
7595 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
7596 /*immediateParent=*/true) ||
7598 /*immediateParent=*/true)) {
7599 minTeamsVal = maxTeamsVal = 1;
7600 } else {
7601 minTeamsVal = maxTeamsVal = -1;
7602 }
7603
7604 // Handle clauses impacting the number of threads.
7605
 // Update `result` from a clause value: the clause's constant when it is
 // one; otherwise (or if that constant is negative) a present clause still
 // moves `result` from "unset" (< 0) to "unknown" (0).
7606 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
7607 if (!clauseValue)
7608 return;
7609
7610 if (auto val = extractConstInteger(clauseValue))
7611 result = *val;
7612
7613 // Found an applicable clause, so it's not undefined. Mark as unknown
7614 // because it's not constant.
7615 if (result < 0)
7616 result = 0;
7617 };
7618
7619 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
7620 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
7621 if (!targetOp.getThreadLimitVars().empty())
7622 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
7623 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
7624
7625 // Extract 'num_threads' clause from 'parallel' or set to 1 if it's SIMD.
7626 int32_t maxThreadsVal = -1;
7628 setMaxValueFromClause(numThreads, maxThreadsVal);
7629 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
7630 /*immediateParent=*/true))
7631 maxThreadsVal = 1;
7632
7633 // For max values, < 0 means unset, == 0 means set but unknown. Select the
7634 // minimum value between 'max_threads' and 'thread_limit' clauses that were
7635 // set.
7636 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
7637 if (combinedMaxThreadsVal < 0 ||
7638 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
7639 combinedMaxThreadsVal = teamsThreadLimitVal;
7640
7641 if (combinedMaxThreadsVal < 0 ||
7642 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
7643 combinedMaxThreadsVal = maxThreadsVal;
7644
7645 int32_t reductionDataSize = 0;
7646 if (isGPU && capturedOp) {
7647 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
7648 reductionDataSize = getReductionDataSize(teamsOp);
7649 }
7650
7651 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
7652 omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
7653 switch (execMode) {
7654 case omp::TargetExecMode::bare:
7655 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
7656 break;
7657 case omp::TargetExecMode::generic:
7658 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
7659 break;
7660 case omp::TargetExecMode::spmd:
7661 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
7662 break;
7663 case omp::TargetExecMode::no_loop:
7664 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
7665 break;
7666 }
7667 attrs.MinTeams = minTeamsVal;
7668 attrs.MaxTeams.front() = maxTeamsVal;
7669 attrs.MinThreads = 1;
7670 attrs.MaxThreads.front() = combinedMaxThreadsVal;
7671 attrs.ReductionDataSize = reductionDataSize;
7672 // TODO: Allow modified buffer length similar to
7673 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
7674 if (attrs.ReductionDataSize != 0)
7675 attrs.ReductionBufferLength = 1024;
7676}
7677
7678/// Gather LLVM runtime values for all clauses evaluated in the host that are
7679/// passed to the kernel invocation.
7680///
7681/// This function must be called only when compiling for the host. Also, it will
7682/// only provide correct results if it's called after the body of \c targetOp
7683/// has been fully generated.
7684static void
7685initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
7686 LLVM::ModuleTranslation &moduleTranslation,
7687 omp::TargetOp targetOp, Operation *capturedOp,
7688 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
7689 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
7690 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
7691
7692 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
7693 llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
7694 steps(numLoops);
7695 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
7696 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
7697
7698 // TODO: Handle constant 'if' clauses.
7699 if (!targetOp.getThreadLimitVars().empty()) {
7700 Value targetThreadLimit = targetOp.getThreadLimit(0);
7701 attrs.TargetThreadLimit.front() =
7702 moduleTranslation.lookupValue(targetThreadLimit);
7703 }
7704
7705 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
7706 // truncate or sign extend lower and upper num_teams bounds as well as
7707 // thread_limit to match int32 ABI requirements for the OpenMP runtime.
7708 if (numTeamsLower)
7709 attrs.MinTeams = builder.CreateSExtOrTrunc(
7710 moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
7711
7712 if (numTeamsUpper)
7713 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7714 moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
7715
7716 if (teamsThreadLimit)
7717 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7718 moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
7719
7720 if (numThreads)
7721 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
7722
7723 bool hostEvalTripCount;
7724 targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
7725 if (hostEvalTripCount) {
7726 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7727 attrs.LoopTripCount = nullptr;
7728
7729 // To calculate the trip count, we multiply together the trip counts of
7730 // every collapsed canonical loop. We don't need to create the loop nests
7731 // here, since we're only interested in the trip count.
7732 for (auto [loopLower, loopUpper, loopStep] :
7733 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7734 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
7735 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
7736 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
7737
7738 if (!lowerBound || !upperBound || !step) {
7739 attrs.LoopTripCount = nullptr;
7740 break;
7741 }
7742
7743 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7744 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7745 loc, lowerBound, upperBound, step, /*IsSigned=*/true,
7746 loopOp.getLoopInclusive());
7747
7748 if (!attrs.LoopTripCount) {
7749 attrs.LoopTripCount = tripCount;
7750 continue;
7751 }
7752
7753 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
7754 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
7755 {}, /*HasNUW=*/true);
7756 }
7757 }
7758
7759 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7760 if (mlir::Value devId = targetOp.getDevice()) {
7761 attrs.DeviceID = moduleTranslation.lookupValue(devId);
7762 attrs.DeviceID =
7763 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
7764 }
7765}
7766
7767static llvm::omp::OMPDynGroupprivateFallbackType
7768getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr) {
7769 omp::FallbackModifier fb = fallbackAttr ? fallbackAttr.getValue()
7770 : omp::FallbackModifier::default_mem;
7771 switch (fb) {
7772 case omp::FallbackModifier::abort:
7773 return llvm::omp::OMPDynGroupprivateFallbackType::Abort;
7774 case omp::FallbackModifier::null:
7775 return llvm::omp::OMPDynGroupprivateFallbackType::Null;
7776 case omp::FallbackModifier::default_mem:
7777 return llvm::omp::OMPDynGroupprivateFallbackType::DefaultMem;
7778 }
7779
7780 llvm_unreachable("unexpected dyn_groupprivate fallback type");
7781}
7782
/// Translates an omp.target op (omp::TargetOp) into LLVM IR by calling
/// OpenMPIRBuilder::createTarget.
///
/// Before that call this function:
///  - rejects the op if checkImplementationStatus fails;
///  - points the builder's debug location at the parent LLVM function;
///  - records, for each private variable whose privatizer needs a map, the
///    map block argument to use for it (`mappedPrivateVars`);
///  - collects map data for the map and has_device_addr operands;
///  - builds the kernel input list from non-constant host_eval values and from
///    map operands that are not declare target, not members and not attach
///    maps;
///  - computes default kernel attributes and, on the host only, runtime
///    kernel attributes;
///  - builds depend data and translates the if, nowait and dyn_groupprivate
///    clauses.
/// The target region body itself is emitted from `bodyCB`.
///
/// Returns failure if any of these steps, or createTarget, reports an error.
///
/// NOTE(review): this listing is missing several source lines; `afterAllocas`,
/// `exitBlock`, `deallocIPs`, `kernelInput` and `deallocBlocks` are used below
/// without a visible declaration. See the NOTE(review) markers — verify
/// against the upstream file.
7783static LogicalResult
7784convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
7785 LLVM::ModuleTranslation &moduleTranslation) {
7786 auto targetOp = cast<omp::TargetOp>(opInst);
7787
7788 // The current debug location already has the DISubprogram for the outlined
7789 // function that will be created for the target op. We save it here so that
7790 // we can set it on the outlined function.
7791 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7792 if (failed(checkImplementationStatus(opInst)))
7793 return failure();
7794
7795 // During the handling of target op, we will generate instructions in the
7796 // parent function like call to the outlined function or branch to a new
7797 // BasicBlock. We set the debug location here to parent function so that those
7798 // get the correct debug locations. For outlined functions, the normal MLIR op
7799 // conversion will automatically pick the correct location.
7800 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7801 assert(parentBB && "No insert block is set for the builder");
7802 llvm::Function *parentLLVMFn = parentBB->getParent();
7803 assert(parentLLVMFn && "Parent Function must be valid");
7804 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7805 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7806 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7807 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7808
7809 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7810 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7811 bool isGPU = ompBuilder->Config.isGPU();
7812
7813 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
7814 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7815 auto &targetRegion = targetOp.getRegion();
7816 // Holds the private vars that have been mapped along with the block
7817 // argument that corresponds to the MapInfoOp corresponding to the private
7818 // var in question. So, for instance:
7819 //
7820 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
7821 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
7822 //
7823 // Then, %10 has been created so that the descriptor can be used by the
7824 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
7825 // %arg0} in the mappedPrivateVars map.
7826 llvm::DenseMap<Value, Value> mappedPrivateVars;
7827 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
7828 SmallVector<Value> mapVars = targetOp.getMapVars();
7829 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
7830 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
7831 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
 // Set by bodyCB once the outlined function exists; consumed at the end for
 // declare-target remapping on the device.
7832 llvm::Function *llvmOutlinedFn = nullptr;
7833 TargetDirectiveEnumTy targetDirective =
7834 getTargetDirectiveEnumTyFromOp(&opInst);
7835
7836 // TODO: It can also be false if a compile-time constant `false` IF clause is
7837 // specified.
7838 bool isOffloadEntry =
7839 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
7840
7841 // For some private variables, the MapsForPrivatizedVariablesPass
7842 // creates MapInfoOp instances. Go through the private variables and
7843 // the mapped variables so that during code generation we are able
7844 // to quickly look up the corresponding map variable, if any for each
7845 // private variable.
7846 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7847 OperandRange privateVars = targetOp.getPrivateVars();
7848 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7849 std::optional<DenseI64ArrayAttr> privateMapIndices =
7850 targetOp.getPrivateMapsAttr();
7851
7852 for (auto [privVarIdx, privVarSymPair] :
7853 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7854 auto privVar = std::get<0>(privVarSymPair);
7855 auto privSym = std::get<1>(privVarSymPair);
7856
7857 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7858 omp::PrivateClauseOp privatizer =
7859 findPrivatizer(targetOp, privatizerName);
7860
7861 if (!privatizer.needsMap())
7862 continue;
7863
7864 mlir::Value mappedValue =
7865 targetOp.getMappedValueForPrivateVar(privVarIdx);
7866 assert(mappedValue && "Expected to find mapped value for a privatized "
7867 "variable that needs mapping");
7868
7869 // The MapInfoOp defining the map var isn't really needed later.
7870 // So, we don't store it in any data structure. Instead, we just
7871 // do some sanity checks on it right now.
7872 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
7873 [[maybe_unused]] Type varType = mapInfoOp.getVarPtrType();
7874
7875 // Check #1: Check that the type of the private variable matches
7876 // the type of the variable being mapped.
7877 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7878 assert(
7879 varType == privVar.getType() &&
7880 "Type of private var doesn't match the type of the mapped value");
7881
7882 // Ok, only 1 sanity check for now.
7883 // Record the block argument corresponding to this mapvar.
7884 mappedPrivateVars.insert(
7885 {privVar,
7886 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7887 (*privateMapIndices)[privVarIdx])});
7888 }
7889 }
7890
7891 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 // Body generator passed to createTarget. It forwards target-cpu /
 // target-features (and the saved DISubprogram) to the outlined function,
 // maps map and has_device_addr block arguments to the translated var_ptr
 // values, then allocates, initializes and firstprivate-copies private vars,
 // translates the target region, and runs private-var cleanup before the
 // region's exit terminator.
7892 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
7893 ArrayRef<llvm::BasicBlock *> deallocBlocks)
7894 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7895 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7896 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7897 // Forward target-cpu and target-features function attributes from the
7898 // original function to the new outlined function.
7899 llvm::Function *llvmParentFn =
7900 moduleTranslation.lookupFunction(parentFn.getName());
7901 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7902 assert(llvmParentFn && llvmOutlinedFn &&
7903 "Both parent and outlined functions must exist at this point");
7904
7905 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7906 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
7907
7908 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
7909 attr.isStringAttribute())
7910 llvmOutlinedFn->addFnAttr(attr);
7911
7912 if (auto attr = llvmParentFn->getFnAttribute("target-features");
7913 attr.isStringAttribute())
7914 llvmOutlinedFn->addFnAttr(attr);
7915
7916 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7917 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7918 llvm::Value *mapOpValue =
7919 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7920 moduleTranslation.mapValue(arg, mapOpValue);
7921 }
7922 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7923 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7924 llvm::Value *mapOpValue =
7925 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7926 moduleTranslation.mapValue(arg, mapOpValue);
7927 }
7928
7929 // Do privatization after moduleTranslation has already recorded
7930 // mapped values.
7931 PrivateVarsInfo privateVarsInfo(targetOp);
7932
 // NOTE(review): the line declaring `afterAllocas` (checked just below) is
 // missing from this listing; presumably it holds the result of the
 // allocatePrivateVars call that follows — verify against upstream.
7934 allocatePrivateVars(targetOp, builder, moduleTranslation,
7935 privateVarsInfo, allocaIP, &mappedPrivateVars);
7936
7937 if (failed(handleError(afterAllocas, *targetOp)))
7938 return llvm::make_error<PreviouslyReportedError>();
7939
7940 builder.restoreIP(codeGenIP);
7941 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
7942 &mappedPrivateVars),
7943 *targetOp)
7944 .failed())
7945 return llvm::make_error<PreviouslyReportedError>();
7946
7947 if (failed(copyFirstPrivateVars(
7948 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
7949 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
7950 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7951 return llvm::make_error<PreviouslyReportedError>();
7952
 // NOTE(review): the statement that the next argument list belongs to is
 // missing from this listing (it takes moduleTranslation, allocaIP and
 // deallocBlocks) — verify against upstream.
7954 moduleTranslation, allocaIP, deallocBlocks);
 // NOTE(review): the line declaring `exitBlock` is missing from this listing;
 // presumably it is an llvm::Expected<llvm::BasicBlock *> produced by
 // translating `targetRegion` with the arguments below — verify.
7956 targetRegion, "omp.target", builder, moduleTranslation);
7957
7958 if (failed(handleError(exitBlock, *targetOp)))
7959 return llvm::make_error<PreviouslyReportedError>();
7960
7961 builder.SetInsertPoint(exitBlock.get()->getTerminator());
7962
7963 if (failed(cleanupPrivateVars(targetOp, builder, moduleTranslation,
7964 targetOp.getLoc(), privateVarsInfo)))
7965 return llvm::make_error<PreviouslyReportedError>();
7966
7967 return builder.saveIP();
7968 };
7969
7970 StringRef parentName = parentFn.getName();
7971
7972 llvm::TargetRegionEntryInfo entryInfo;
7973
7974 getTargetEntryUniqueInfo(entryInfo, targetOp,
7975 *moduleTranslation.getOpenMPBuilder(),
7976 moduleTranslation.getFileSystem(), parentName);
7977
7978 MapInfoData mapData;
7979 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
7980 builder, /*useDevPtrOperands=*/{},
7981 /*useDevAddrOperands=*/{}, hdaVars);
7982
7983 MapInfosTy combinedInfos;
 // Fills `combinedInfos` from `mapData` at the given insertion point and
 // returns it to createTarget.
7984 auto genMapInfoCB =
7985 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7986 builder.restoreIP(codeGenIP);
7987 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7988 targetDirective);
7989
7990 // Append a null entry for the implicit dyn_ptr argument so the argument
7991 // count sent to the runtime already includes it.
7992 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7993 combinedInfos.BasePointers.push_back(nullPtr);
7994 combinedInfos.Pointers.push_back(nullPtr);
7995 combinedInfos.DevicePointers.push_back(
7996 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7997 combinedInfos.Sizes.push_back(builder.getInt64(0));
7998 combinedInfos.Types.push_back(
7999 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
8000 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
8001 if (!combinedInfos.Names.empty())
8002 combinedInfos.Names.push_back(nullPtr);
8003 combinedInfos.Mappers.push_back(nullptr);
8004
8005 return combinedInfos;
8006 };
8007
 // Produces the value used inside the outlined function for kernel argument
 // `arg`: the argument itself on the host, otherwise whatever
 // createDeviceArgumentAccessor builds.
8008 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
8009 llvm::Value *&retVal, InsertPointTy allocaIP,
8010 InsertPointTy codeGenIP,
 // NOTE(review): a parameter line is missing from this listing; `deallocIPs`
 // (used in the call below) is presumably declared there — verify.
8012 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
8013 llvm::IRBuilderBase::InsertPointGuard guard(builder);
8014 builder.SetCurrentDebugLocation(llvm::DebugLoc());
8015 // We just return the unaltered argument for the host function
8016 // for now, some alterations may be required in the future to
8017 // keep host fallback functions working identically to the device
8018 // version (e.g. pass ByCopy values should be treated as such on
8019 // host and device, currently not always the case)
8020 if (!isTargetDevice) {
8021 retVal = cast<llvm::Value>(&arg);
8022 return codeGenIP;
8023 }
8024
8025 return createDeviceArgumentAccessor(targetOp, mapData, arg, input, retVal,
8026 builder, *ompBuilder, moduleTranslation,
8027 allocaIP, codeGenIP, deallocIPs);
8028 };
8029
8030 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
8031 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
8032 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
8033 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
8034 isTargetDevice, isGPU);
8035
8036 // Collect host-evaluated values needed to properly launch the kernel from the
8037 // host.
8038 if (!isTargetDevice)
8039 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
8040 targetCapturedOp, runtimeAttrs);
8041
8042 // Pass host-evaluated values as parameters to the kernel / host fallback,
8043 // except if they are constants. In any case, map the MLIR block argument to
8044 // the corresponding LLVM values.
 // NOTE(review): the declaration of `kernelInput` (a container of
 // llvm::Value * filled below) is missing from this listing — verify.
8046 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
8047 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
8048 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
8049 llvm::Value *value = moduleTranslation.lookupValue(var);
8050 moduleTranslation.mapValue(arg, value);
8051
8052 if (!llvm::isa<llvm::Constant>(value))
8053 kernelInput.push_back(value);
8054 }
8055
8056 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
8057 // 1) Declare target arguments are not passed to kernels as arguments.
8058 // 2) Attach maps are not passed in as arguments to kernels.
8059 // 3) Children of record objects are not passed in as arguments.
8060 // TODO: We currently do not handle cases where a member is explicitly
8061 // passed in as an argument, this will likely need to be handled in
8062 // the near future, rather than using IsAMember, it may be better to
8063 // test if the relevant BlockArg is used within the target region and
8064 // then use that as a basis for exclusion in the kernel inputs.
8065 bool isAttachMap = (mapData.Types[i] &
8066 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH) ==
8067 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
8068 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i] && !isAttachMap)
8069 kernelInput.push_back(mapData.OriginalValue[i]);
8070 }
8071
 // NOTE(review): the declaration of `deallocBlocks` (filled by
 // findAllocInsertPoints below and passed to createTarget) is missing from
 // this listing — verify.
8073 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
8074 findAllocInsertPoints(builder, moduleTranslation, &deallocBlocks);
8075
8076 llvm::OpenMPIRBuilder::DependenciesInfo dds;
8077 if (failed(buildDependData(
8078 targetOp.getDependVars(), targetOp.getDependKinds(),
8079 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
8080 builder, moduleTranslation, dds)))
8081 return failure();
8082
8083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8084
8085 llvm::OpenMPIRBuilder::TargetDataInfo info(
8086 /*RequiresDevicePointerInfo=*/false,
8087 /*SeparateBeginEndCalls=*/true);
8088
 // Returns the user-defined mapper function for map entry `i`, or nullptr if
 // that entry has no mapper; records in `info` that a mapper is in use.
8089 auto customMapperCB =
8090 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
8091 if (!combinedInfos.Mappers[i])
8092 return nullptr;
8093 info.HasMapper = true;
8094 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
8095 moduleTranslation, targetDirective);
8096 };
8097
8098 llvm::Value *ifCond = nullptr;
8099 if (Value targetIfCond = targetOp.getIfExpr())
8100 ifCond = moduleTranslation.lookupValue(targetIfCond);
8101
 // The dyn_groupprivate size, when present, is passed as an i32 obtained via
 // an unsigned integer cast.
8102 Value dynGroupPrivateSize = targetOp.getDynGroupprivateSize();
8103 llvm::Value *dynSizeVal = nullptr;
8104 if (dynGroupPrivateSize) {
8105 dynSizeVal = moduleTranslation.lookupValue(dynGroupPrivateSize);
8106 dynSizeVal = builder.CreateIntCast(dynSizeVal, builder.getInt32Ty(),
8107 /*isSigned=*/false);
8108 }
8109
8110 llvm::omp::OMPDynGroupprivateFallbackType fallbackType =
8111 getDynGroupprivateFallbackType(targetOp.getDynGroupprivateFallbackAttr());
8112
8113 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8114 moduleTranslation.getOpenMPBuilder()->createTarget(
8115 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), deallocBlocks,
8116 info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput,
8117 genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
8118 targetOp.getNowait(), dynSizeVal, fallbackType);
8119
8120 if (failed(handleError(afterIP, opInst)))
8121 return failure();
8122
8123 builder.restoreIP(*afterIP);
8124
 // Release the dependence array if buildDependData produced one.
8125 if (dds.DepArray)
8126 builder.CreateFree(dds.DepArray);
8127
8128 // Remap access operations to declare target reference pointers for the
8129 // device, essentially generating extra load ops as necessary.
8130 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
8131 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
8132 llvmOutlinedFn);
8133
8134 return success();
8135}
8136
8137static LogicalResult
8138convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
8139 llvm::OpenMPIRBuilder *ompBuilder,
8140 LLVM::ModuleTranslation &moduleTranslation) {
8141 // Amend omp.declare_target by deleting the IR of the outlined functions
8142 // created for target regions. They cannot be filtered out from MLIR earlier
8143 // because the omp.target operation inside must be translated to LLVM, but
8144 // the wrapper functions themselves must not remain at the end of the
8145 // process. We know that functions where omp.declare_target does not match
8146 // omp.is_target_device at this stage can only be wrapper functions because
8147 // those that aren't are removed earlier as an MLIR transformation pass.
8148 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
8149 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
8150 op->getParentOfType<ModuleOp>().getOperation())) {
8151 if (!offloadMod.getIsTargetDevice())
8152 return success();
8153
8154 omp::DeclareTargetDeviceType declareType =
8155 attribute.getDeviceType().getValue();
8156
8157 if (declareType == omp::DeclareTargetDeviceType::host) {
8158 llvm::Function *llvmFunc =
8159 moduleTranslation.lookupFunction(funcOp.getName());
8160 llvmFunc->dropAllReferences();
8161 llvmFunc->eraseFromParent();
8162
8163 // Invalidate the builder's current insertion point, as it now points to
8164 // a deleted block.
8165 ompBuilder->Builder.ClearInsertionPoint();
8166 ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
8167 }
8168 }
8169 return success();
8170 }
8171
8172 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
8173 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8174 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
8175 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8176 bool isDeclaration = gOp.isDeclaration();
8177 bool isExternallyVisible =
8178 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
8179 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
8180 llvm::StringRef mangledName = gOp.getSymName();
8181 auto captureClause =
8182 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
8183 auto deviceClause =
8184 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
8185 // unused for MLIR at the moment, required in Clang for book
8186 // keeping
8187 std::vector<llvm::GlobalVariable *> generatedRefs;
8188
8189 std::vector<llvm::Triple> targetTriple;
8190 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
8191 op->getParentOfType<mlir::ModuleOp>()->getAttr(
8192 LLVM::LLVMDialect::getTargetTripleAttrName()));
8193 if (targetTripleAttr)
8194 targetTriple.emplace_back(targetTripleAttr.data());
8195
8196 auto fileInfoCallBack = [&loc]() {
8197 std::string filename = "";
8198 std::uint64_t lineNo = 0;
8199
8200 if (loc) {
8201 filename = loc.getFilename().str();
8202 lineNo = loc.getLine();
8203 }
8204
8205 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
8206 lineNo);
8207 };
8208
8209 llvm::vfs::FileSystem &vfs = moduleTranslation.getFileSystem();
8210
8211 ompBuilder->registerTargetGlobalVariable(
8212 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8213 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8214 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
8215 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
8216 gVal->getType(), gVal);
8217
8218 bool requiresUSM = ompBuilder->Config.hasRequiresUnifiedSharedMemory();
8219 if (ompBuilder->Config.isTargetDevice() &&
8220 (attribute.getCaptureClause().getValue() ==
8221 mlir::omp::DeclareTargetCaptureClause::link ||
8222 requiresUSM)) {
8223 llvm::Type *ptrTy = gVal->getType();
8224 // For USM the global type becomes a pointer handle, as opposed to the
8225 // globals original type.
8226 if (requiresUSM)
8227 ptrTy = llvm::PointerType::get(llvmModule->getContext(), 0);
8228 ompBuilder->getAddrOfDeclareTargetVar(
8229 captureClause, deviceClause, isDeclaration, isExternallyVisible,
8230 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, vfs),
8231 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
8232 ptrTy, /*GlobalInitializer*/ nullptr,
8233 /*VariableLinkage*/ nullptr);
8234 }
8235 }
8236 }
8237
8238 return success();
8239}
8240
8241namespace {
8242
8243/// Implementation of the dialect interface that converts operations belonging
8244/// to the OpenMP dialect to LLVM IR.
8245class OpenMPDialectLLVMIRTranslationInterface
8246 : public LLVMTranslationDialectInterface {
8247public:
8248 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
8249
8250 /// Translates the given operation to LLVM IR using the provided IR builder
8251 /// and saving the state in `moduleTranslation`.
8252 LogicalResult
8253 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
8254 LLVM::ModuleTranslation &moduleTranslation) const final;
8255
8256 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
8257 /// runtime calls, or operation amendments
8258 LogicalResult
8259 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8260 NamedAttribute attribute,
8261 LLVM::ModuleTranslation &moduleTranslation) const final;
8262
8263 /// Records the LLVM alloc pointer produced for an OMP ALLOCATE variable so
8264 /// that the paired omp.allocate_free op can generate the matching
8265 /// __kmpc_free call.
8266 void registerAllocatedPtr(Value var, llvm::Value *ptr) const {
8267 ompAllocatedPtrs[var] = ptr;
8268 }
8269
8270 /// Returns the LLVM alloc pointer previously registered for var, or
8271 /// nullptr if no allocation was recorded.
8272 llvm::Value *lookupAllocatedPtr(Value var) const {
8273 auto it = ompAllocatedPtrs.find(var);
8274 return it != ompAllocatedPtrs.end() ? it->second : nullptr;
8275 }
8276
8277private:
8278 /// Maps each MLIR variable value that appeared in an omp.allocate_dir op to
8279 /// the LLVM pointer returned by the corresponding __kmpc_alloc call. The
8280 /// paired omp.allocate_free op looks up these pointers to emit __kmpc_free.
8281 mutable DenseMap<Value, llvm::Value *> ompAllocatedPtrs;
8282};
8283
8284} // namespace
8285
8286LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
8287 Operation *op, ArrayRef<llvm::Instruction *> instructions,
8288 NamedAttribute attribute,
8289 LLVM::ModuleTranslation &moduleTranslation) const {
8290 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
8291 attribute.getName())
8292 .Case("omp.is_target_device",
8293 [&](Attribute attr) {
8294 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
8295 llvm::OpenMPIRBuilderConfig &config =
8296 moduleTranslation.getOpenMPBuilder()->Config;
8297 config.setIsTargetDevice(deviceAttr.getValue());
8298 return success();
8299 }
8300 return failure();
8301 })
8302 .Case("omp.is_gpu",
8303 [&](Attribute attr) {
8304 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
8305 llvm::OpenMPIRBuilderConfig &config =
8306 moduleTranslation.getOpenMPBuilder()->Config;
8307 config.setIsGPU(gpuAttr.getValue());
8308 return success();
8309 }
8310 return failure();
8311 })
8312 .Case("omp.host_ir_filepath",
8313 [&](Attribute attr) {
8314 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
8315 llvm::OpenMPIRBuilder *ompBuilder =
8316 moduleTranslation.getOpenMPBuilder();
8317 ompBuilder->loadOffloadInfoMetadata(
8318 moduleTranslation.getFileSystem(), filepathAttr.getValue());
8319 return success();
8320 }
8321 return failure();
8322 })
8323 .Case("omp.flags",
8324 [&](Attribute attr) {
8325 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
8326 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
8327 return failure();
8328 })
8329 .Case("omp.version",
8330 [&](Attribute attr) {
8331 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
8332 llvm::OpenMPIRBuilder *ompBuilder =
8333 moduleTranslation.getOpenMPBuilder();
8334 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
8335 versionAttr.getVersion());
8336 return success();
8337 }
8338 return failure();
8339 })
8340 .Case("omp.declare_target",
8341 [&](Attribute attr) {
8342 if (auto declareTargetAttr =
8343 dyn_cast<omp::DeclareTargetAttr>(attr)) {
8344 llvm::OpenMPIRBuilder *ompBuilder =
8345 moduleTranslation.getOpenMPBuilder();
8346 return convertDeclareTargetAttr(op, declareTargetAttr,
8347 ompBuilder, moduleTranslation);
8348 }
8349 return failure();
8350 })
8351 .Case("omp.requires",
8352 [&](Attribute attr) {
8353 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
8354 using Requires = omp::ClauseRequires;
8355 Requires flags = requiresAttr.getValue();
8356 llvm::OpenMPIRBuilderConfig &config =
8357 moduleTranslation.getOpenMPBuilder()->Config;
8358 config.setHasRequiresReverseOffload(
8359 bitEnumContainsAll(flags, Requires::reverse_offload));
8360 config.setHasRequiresUnifiedAddress(
8361 bitEnumContainsAll(flags, Requires::unified_address));
8362 config.setHasRequiresUnifiedSharedMemory(
8363 bitEnumContainsAll(flags, Requires::unified_shared_memory));
8364 config.setHasRequiresDynamicAllocators(
8365 bitEnumContainsAll(flags, Requires::dynamic_allocators));
8366 return success();
8367 }
8368 return failure();
8369 })
8370 .Case("omp.target_triples",
8371 [&](Attribute attr) {
8372 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
8373 llvm::OpenMPIRBuilderConfig &config =
8374 moduleTranslation.getOpenMPBuilder()->Config;
8375 config.TargetTriples.clear();
8376 config.TargetTriples.reserve(triplesAttr.size());
8377 for (Attribute tripleAttr : triplesAttr) {
8378 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
8379 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
8380 else
8381 return failure();
8382 }
8383 return success();
8384 }
8385 return failure();
8386 })
8387 .Default([](Attribute) {
8388 // Fall through for omp attributes that do not require lowering.
8389 return success();
8390 })(attribute.getValue());
8391
8392 return failure();
8393}
8394
8395// Returns true if the operation is not inside a TargetOp, it is part of a
8396// function and that function is not declare target.
8397static bool isHostDeviceOp(Operation *op) {
8398 // Assumes no reverse offloading
8399 if (op->getParentOfType<omp::TargetOp>())
8400 return false;
8401
8402 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
8403 if (auto declareTargetIface =
8404 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
8405 parentFn.getOperation()))
8406 if (declareTargetIface.isDeclareTarget() &&
8407 declareTargetIface.getDeclareTargetDeviceType() !=
8408 mlir::omp::DeclareTargetDeviceType::host)
8409 return false;
8410
8411 return true;
8412 }
8413
8414 return false;
8415}
8416
8417static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
8418 llvm::Module *llvmModule) {
8419 llvm::Type *i64Ty = builder.getInt64Ty();
8420 llvm::Type *i32Ty = builder.getInt32Ty();
8421 llvm::Type *returnType = builder.getPtrTy(0);
8422 llvm::FunctionType *fnType =
8423 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
8424 llvm::Function *func = cast<llvm::Function>(
8425 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
8426 return func;
8427}
8428
8429template <typename T>
8430static llvm::Value *
8431getAllocationSize(llvm::IRBuilderBase &builder,
8432 LLVM::ModuleTranslation &moduleTranslation, T op) {
8433 llvm::DataLayout dataLayout =
8434 moduleTranslation.getLLVMModule()->getDataLayout();
8435 llvm::Type *llvmHeapTy =
8436 moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
8437
8438 auto alignment = op.getMemAlignment();
8439 llvm::TypeSize typeSize = llvm::alignTo(
8440 dataLayout.getTypeStoreSize(llvmHeapTy),
8441 alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
8442
8443 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8444 return builder.CreateMul(
8445 allocSize,
8446 builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
8447 builder.getInt64Ty(),
8448 /*isSigned=*/false));
8449}
8450
8451template <>
8452llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
8453 LLVM::ModuleTranslation &moduleTranslation,
8454 omp::TargetAllocMemOp op) {
8455 llvm::DataLayout dataLayout =
8456 moduleTranslation.getLLVMModule()->getDataLayout();
8457 llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
8458 llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
8459 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
8460 for (auto typeParam : op.getTypeparams()) {
8461 allocSize = builder.CreateMul(
8462 allocSize,
8463 builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
8464 builder.getInt64Ty(),
8465 /*isSigned=*/false));
8466 }
8467 return allocSize;
8468}
8469
8470static LogicalResult
8471convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
8472 LLVM::ModuleTranslation &moduleTranslation) {
8473 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
8474 if (!allocMemOp)
8475 return failure();
8476
8477 // Get "omp_target_alloc" function
8478 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8479 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
8480 // Get the corresponding device value in llvm
8481 mlir::Value deviceNum = allocMemOp.getDevice();
8482 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
8483 // Get the allocation size.
8484 llvm::Value *allocSize =
8485 getAllocationSize(builder, moduleTranslation, allocMemOp);
8486 // Create call to "omp_target_alloc" with the args as translated llvm values.
8487 llvm::CallInst *call =
8488 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
8489 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
8490
8491 // Map the result
8492 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
8493 return success();
8494}
8495
8496static LogicalResult
8497convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
8498 llvm::IRBuilderBase &builder,
8499 LLVM::ModuleTranslation &moduleTranslation) {
8500 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8501 llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
8502 moduleTranslation.mapValue(allocMemOp.getResult(),
8503 ompBuilder->createOMPAllocShared(builder, size));
8504 return success();
8505}
8506
8507static LogicalResult
8508convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
8509 LLVM::ModuleTranslation &moduleTranslation,
8510 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8511 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
8512 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8513
8514 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8515 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8516 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
8517 SmallVector<Value> vars = allocateDirOp.getVarList();
8518 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
8519
8520 llvm::Value *allocator;
8521 if (auto allocatorVar = allocateDirOp.getAllocator()) {
8522 allocator = moduleTranslation.lookupValue(allocatorVar);
8523 if (allocator->getType()->isIntegerTy())
8524 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8525 else if (allocator->getType()->isPointerTy())
8526 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8527 allocator, builder.getPtrTy());
8528 } else {
8529 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8530 }
8531
8532 for (Value var : vars) {
8533 llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType());
8534
8535 // Opaque pointers lose element type. Trace to GlobalOp for type
8536 // Falls back to llvmVarTy when not from a global.
8537 llvm::Type *typeToInspect = llvmVarTy;
8538 if (llvmVarTy->isPointerTy()) {
8539 Value baseVar = getBaseValueForTypeLookup(var);
8540 if (Operation *globalOp = getGlobalOpFromValue(baseVar)) {
8541 if (auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
8542 typeToInspect = moduleTranslation.convertType(gop.getGlobalType());
8543 }
8544 }
8545
8546 llvm::Value *size;
8547 if (auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
8548 llvm::Value *elementCount = builder.getInt64(1);
8549 llvm::Type *currentType = arrTy;
8550 while (auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
8551 elementCount = builder.CreateMul(
8552 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
8553 currentType = nestedArrTy->getElementType();
8554 }
8555 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
8556 size =
8557 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
8558 } else {
8559 size = builder.getInt64(
8560 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
8561 }
8562
8563 uint64_t alignValue =
8564 alignAttr ? alignAttr.value()
8565 : dataLayout.getABITypeAlign(typeToInspect).value();
8566 llvm::Value *alignConst = builder.getInt64(alignValue);
8567 // Align the size: ((size + align - 1) / align) * align
8568 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true);
8569 size = builder.CreateUDiv(size, alignConst);
8570 size = builder.CreateMul(size, alignConst, "", true);
8571
8572 std::string allocName =
8573 ompBuilder->createPlatformSpecificName({".void.addr"});
8574 llvm::CallInst *allocCall;
8575 if (alignAttr.has_value()) {
8576 allocCall = ompBuilder->createOMPAlignedAlloc(
8577 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
8578 allocName);
8579 } else {
8580 allocCall =
8581 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
8582 }
8583 // Record the alloc pointer keyed by the MLIR variable value.
8584 ompIface.registerAllocatedPtr(var, allocCall);
8585 }
8586
8587 return success();
8588}
8589
8590static LogicalResult
8591convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder,
8592 LLVM::ModuleTranslation &moduleTranslation,
8593 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
8594 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
8595 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8596 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
8597
8598 llvm::Value *allocator;
8599 if (auto allocatorVar = freeOp.getAllocator()) {
8600 allocator = moduleTranslation.lookupValue(allocatorVar);
8601 if (allocator->getType()->isIntegerTy())
8602 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
8603 else if (allocator->getType()->isPointerTy())
8604 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
8605 allocator, builder.getPtrTy());
8606 } else {
8607 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
8608 }
8609
8610 // Emit __kmpc_free for each variable in reverse allocation order.
8611 SmallVector<Value> vars = freeOp.getVarList();
8612 for (Value var : llvm::reverse(vars)) {
8613 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
8614 if (!allocPtr)
8615 return opInst.emitError("omp.allocate_free: no allocation recorded");
8616 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator, "");
8617 }
8618
8619 return success();
8620}
8621
8622static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
8623 llvm::Module *llvmModule) {
8624 llvm::Type *ptrTy = builder.getPtrTy(0);
8625 llvm::Type *i32Ty = builder.getInt32Ty();
8626 llvm::Type *voidTy = builder.getVoidTy();
8627 llvm::FunctionType *fnType =
8628 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
8629 llvm::Function *func = dyn_cast<llvm::Function>(
8630 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
8631 return func;
8632}
8633
8634static LogicalResult
8635convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
8636 LLVM::ModuleTranslation &moduleTranslation) {
8637 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
8638 if (!freeMemOp)
8639 return failure();
8640
8641 // Get "omp_target_free" function
8642 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8643 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
8644 // Get the corresponding device value in llvm
8645 mlir::Value deviceNum = freeMemOp.getDevice();
8646 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
8647 // Get the corresponding heapref value in llvm
8648 mlir::Value heapref = freeMemOp.getHeapref();
8649 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
8650 // Convert heapref int to ptr and call "omp_target_free"
8651 llvm::Value *intToPtr =
8652 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
8653 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
8654 return success();
8655}
8656
8657static LogicalResult
8658convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
8659 llvm::IRBuilderBase &builder,
8660 LLVM::ModuleTranslation &moduleTranslation) {
8661 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8662 llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
8663 ompBuilder->createOMPFreeShared(
8664 builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
8665 return success();
8666}
8667
8668/// Converts an OpenMP groupprivate operation into LLVM IR.
///
/// The op's single result is mapped to one of two values:
///  - On a target device whose device_type permits allocation, a new
///    internal global of the variable's type placed in the target's shared
///    address space (AMDGCN: LOCAL_ADDRESS, NVPTX: ADDRESS_SPACE_SHARED).
///    Any other device triple is rejected with an error.
///  - Otherwise, the original LLVM global. On the host, when allocation
///    would have been requested, a warning is emitted first.
///
/// Fails if checkImplementationStatus rejects the op or the symbol does not
/// resolve to an LLVM global.
8669static LogicalResult
8670convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
8671 LLVM::ModuleTranslation &moduleTranslation) {
8672 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8673 auto groupprivateOp = cast<omp::GroupprivateOp>(opInst);
8674
8675 if (failed(checkImplementationStatus(opInst)))
8676 return failure();
8677
8678 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
8679
8680 // Determine whether group-private storage should be allocated based on
8681 // device_type. When not specified, default to 'any' (allocate on both).
8682 bool shouldAllocate = true;
8683 switch (groupprivateOp.getDeviceType().value_or(
8684 mlir::omp::DeclareTargetDeviceType::any)) {
8685 case mlir::omp::DeclareTargetDeviceType::host:
8686 shouldAllocate = !isTargetDevice;
8687 break;
8688 case mlir::omp::DeclareTargetDeviceType::nohost:
8689 shouldAllocate = isTargetDevice;
8690 break;
8691 case mlir::omp::DeclareTargetDeviceType::any:
8692 shouldAllocate = true;
8693 break;
8694 }
8695
8696 // Look up the global variable directly by symbol name.
8698 &opInst, groupprivateOp.getSymNameAttr());
8699 if (!global)
8700 return opInst.emitError()
8701 << "expected symbol '" << groupprivateOp.getSymName()
8702 << "' to reference an LLVM global variable";
8703
// The already-translated LLVM global, its converted value type, and its
// name are reused to build the device-side copy below.
8704 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
8705 llvm::Type *varType = moduleTranslation.convertType(global.getType());
8706 std::string varName = globalValue->getName().str();
8707
8708 llvm::Value *resultPtr;
8709 if (shouldAllocate && isTargetDevice) {
// Select the shared address space from the module's target triple; only
// AMDGCN and NVPTX are handled.
8710 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
8711 llvm::Triple targetTriple(llvmModule->getTargetTriple());
8712 unsigned sharedAddressSpace;
8713 if (targetTriple.isAMDGCN())
8714 sharedAddressSpace = llvm::AMDGPUAS::LOCAL_ADDRESS;
8715 else if (targetTriple.isNVPTX())
8716 sharedAddressSpace = llvm::NVPTXAS::ADDRESS_SPACE_SHARED;
8717 else
8718 return opInst.emitError() << "groupprivate is not supported for target: "
8719 << targetTriple.str();
// Non-constant internal global with a poison initializer, reusing the
// original name (LLVM uniquifies it on a clash). NOTE(review): poison is
// presumably used because shared-memory globals cannot carry a real
// initializer on these targets; confirm.
8720 llvm::GlobalVariable *sharedVar = new llvm::GlobalVariable(
8721 *llvmModule, varType, /*isConstant=*/false,
8722 llvm::GlobalValue::InternalLinkage, llvm::PoisonValue::get(varType),
8723 varName, /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
8724 sharedAddressSpace,
8725 /*isExternallyInitialized=*/false);
8726 resultPtr = sharedVar;
8727 } else {
// Host compilation, or device_type excludes this side: keep using the
// original global (with a warning when the host would have allocated).
8728 if (shouldAllocate && !isTargetDevice)
8729 opInst.emitWarning("groupprivate directive is currently ignored on the "
8730 "host, using original global");
8731 resultPtr = globalValue;
8732 }
8733
// Uses of the op's result now refer to the chosen storage.
8734 moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
8735 return success();
8736}
8737
8738/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
8739/// OpenMP runtime calls).
///
/// A TypeSwitch dispatches each op kind to its translation helper. Some ops
/// only need a trivial success because they are consumed by their owning
/// ops. Any op without a case reaches the Default handler, which reports
/// "not yet implemented".
8740LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
8741 Operation *op, llvm::IRBuilderBase &builder,
8742 LLVM::ModuleTranslation &moduleTranslation) const {
8743 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
8744
// When compiling for a target device, reject ops for which isHostDeviceOp()
// holds, except for the op kinds listed in the isa<> below.
8745 if (ompBuilder->Config.isTargetDevice() &&
8746 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
8747 op) &&
8748 isHostDeviceOp(op))
8749 return op->emitOpError() << "unsupported host op found in device";
8750
8751 // For each loop, introduce one stack frame to hold loop information. Ensure
8752 // this is only done for the outermost loop wrapper to prevent introducing
8753 // multiple stack frames for a single loop. Initially set to null, the loop
8754 // information structure is initialized during translation of the nested
8755 // omp.loop_nest operation, making it available to translation of all loop
8756 // wrappers after their body has been successfully translated.
8757 bool isOutermostLoopWrapper =
8758 isa_and_present<omp::LoopWrapperInterface>(op) &&
8759 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
8760
8761 // The TASKLOOP construct is implemented with an outer taskloop.context
8762 // operation which is not a loop wrapper, containing an inner taskloop
8763 // operation which is a loop wrapper. The stack frame should be pushed when
8764 // translating the outer taskloop.context and popped when translating the
8765 // inner taskloop which is a loop wrapper. We need access to the loop
8766 // information in the outer taskloop context so we need to create it and pop
8767 // it around the taskloop context not the inner loop wrapper.
8768 if (isa<omp::TaskloopContextOp>(op))
8769 isOutermostLoopWrapper = true;
8770 else if (isa<omp::TaskloopWrapperOp>(op))
8771 isOutermostLoopWrapper = false;
8772
8773 if (isOutermostLoopWrapper)
8774 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
8775
// Translate the op with its dedicated helper. The result is kept so that
// the loop-info frame pushed above (if any) can be popped before returning.
8776 auto result =
8777 llvm::TypeSwitch<Operation *, LogicalResult>(op)
8778 .Case([&](omp::BarrierOp op) -> LogicalResult {
8780 return failure();
8781
8782 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
8783 ompBuilder->createBarrier(builder.saveIP(),
8784 llvm::omp::OMPD_barrier);
8785 LogicalResult res = handleError(afterIP, *op);
8786 if (res.succeeded()) {
8787 // If the barrier generated a cancellation check, the insertion
8788 // point might now need to be changed to a new continuation block.
8789 builder.restoreIP(*afterIP);
8790 }
8791 return res;
8792 })
8793 .Case([&](omp::TaskyieldOp op) {
8795 return failure();
8796
8797 ompBuilder->createTaskyield(builder.saveIP());
8798 return success();
8799 })
8800 .Case([&](omp::FlushOp op) {
8802 return failure();
8803
8804 // No support in OpenMP runtime function (__kmpc_flush) to accept
8805 // the argument list.
8806 // OpenMP standard states the following:
8807 // "An implementation may implement a flush with a list by ignoring
8808 // the list, and treating it the same as a flush without a list."
8809 //
8810 // The argument list is discarded so that, flush with a list is
8811 // treated same as a flush without a list.
8812 ompBuilder->createFlush(builder.saveIP());
8813 return success();
8814 })
8815 .Case([&](omp::ParallelOp op) {
8816 return convertOmpParallel(op, builder, moduleTranslation);
8817 })
8818 .Case([&](omp::MaskedOp) {
8819 return convertOmpMasked(*op, builder, moduleTranslation);
8820 })
8821 .Case([&](omp::MasterOp) {
8822 return convertOmpMaster(*op, builder, moduleTranslation);
8823 })
8824 .Case([&](omp::CriticalOp) {
8825 return convertOmpCritical(*op, builder, moduleTranslation);
8826 })
8827 .Case([&](omp::OrderedRegionOp) {
8828 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
8829 })
8830 .Case([&](omp::OrderedOp) {
8831 return convertOmpOrdered(*op, builder, moduleTranslation);
8832 })
8833 .Case([&](omp::WsloopOp) {
8834 return convertOmpWsloop(*op, builder, moduleTranslation);
8835 })
8836 .Case([&](omp::SimdOp) {
8837 return convertOmpSimd(*op, builder, moduleTranslation);
8838 })
8839 .Case([&](omp::AtomicReadOp) {
8840 return convertOmpAtomicRead(*op, builder, moduleTranslation);
8841 })
8842 .Case([&](omp::AtomicWriteOp) {
8843 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
8844 })
8845 .Case([&](omp::AtomicUpdateOp op) {
8846 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
8847 })
8848 .Case([&](omp::AtomicCaptureOp op) {
8849 return convertOmpAtomicCapture(op, builder, moduleTranslation);
8850 })
8851 .Case([&](omp::AtomicCompareOp op) {
8852 return convertOmpAtomicCompare(op, builder, moduleTranslation);
8853 })
8854 .Case([&](omp::CancelOp op) {
8855 return convertOmpCancel(op, builder, moduleTranslation);
8856 })
8857 .Case([&](omp::CancellationPointOp op) {
8858 return convertOmpCancellationPoint(op, builder, moduleTranslation);
8859 })
8860 .Case([&](omp::SectionsOp) {
8861 return convertOmpSections(*op, builder, moduleTranslation);
8862 })
8863 .Case([&](omp::ScopeOp op) {
8864 return convertOmpScope(op, builder, moduleTranslation);
8865 })
8866 .Case([&](omp::SingleOp op) {
8867 return convertOmpSingle(op, builder, moduleTranslation);
8868 })
8869 .Case([&](omp::TeamsOp op) {
8870 return convertOmpTeams(op, builder, moduleTranslation);
8871 })
8872 .Case([&](omp::TaskOp op) {
8873 return convertOmpTaskOp(op, builder, moduleTranslation);
8874 })
8875 .Case([&](omp::TaskloopWrapperOp op) {
8876 return convertOmpTaskloopWrapperOp(op, builder, moduleTranslation);
8877 })
8878 .Case([&](omp::TaskloopContextOp op) {
8879 return convertOmpTaskloopContextOp(op, builder, moduleTranslation);
8880 })
8881 .Case([&](omp::TaskgroupOp op) {
8882 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
8883 })
8884 .Case([&](omp::TaskwaitOp op) {
8885 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
8886 })
8887 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8888 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8889 omp::CriticalDeclareOp>([](auto op) {
8890 // `yield` and `terminator` can be just omitted. The block structure
8891 // was created in the region that handles their parent operation.
8892 // `declare_reduction` will be used by reductions and is not
8893 // converted directly, skip it.
8894 // `declare_mapper` and `declare_mapper.info` are handled whenever
8895 // they are referred to through a `map` clause.
8896 // `critical.declare` is only used to declare names of critical
8897 // sections which will be used by `critical` ops and hence can be
8898 // ignored for lowering. The OpenMP IRBuilder will create unique
8899 // name for critical section names.
8900 return success();
8901 })
8902 .Case([&](omp::ThreadprivateOp) {
8903 return convertOmpThreadprivate(*op, builder, moduleTranslation);
8904 })
8905 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8906 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
8907 return convertOmpTargetData(op, builder, moduleTranslation);
8908 })
8909 .Case([&](omp::TargetOp) {
8910 return convertOmpTarget(*op, builder, moduleTranslation);
8911 })
8912 .Case([&](omp::DistributeOp) {
8913 return convertOmpDistribute(*op, builder, moduleTranslation);
8914 })
8915 .Case([&](omp::LoopNestOp) {
8916 return convertOmpLoopNest(*op, builder, moduleTranslation);
8917 })
8918 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8919 omp::AffinityEntryOp, omp::IteratorOp>([&](auto op) {
8920 // No-op, should be handled by relevant owning operations e.g.
8921 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
8922 // etc. and then discarded
8923 return success();
8924 })
8925 .Case([&](omp::NewCliOp op) {
8926 // Meta-operation: Doesn't do anything by itself, but used to
8927 // identify a loop.
8928 return success();
8929 })
8930 .Case([&](omp::CanonicalLoopOp op) {
8931 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
8932 })
8933 .Case([&](omp::UnrollHeuristicOp op) {
8934 // FIXME: Handling omp.unroll_heuristic as an executable requires
8935 // that the generator (e.g. omp.canonical_loop) has been seen first.
8936 // For constructs that require all codegen to occur inside a callback
8937 // (e.g. OpenMPIRBuilder::createParallel), all codegen of that
8938 // contained region including their transformations must occur at
8939 // the omp.canonical_loop.
8940 return applyUnrollHeuristic(op, builder, moduleTranslation);
8941 })
8942 .Case([&](omp::TileOp op) {
8943 return applyTile(op, builder, moduleTranslation);
8944 })
8945 .Case([&](omp::FuseOp op) {
8946 return applyFuse(op, builder, moduleTranslation);
8947 })
8948 .Case([&](omp::TargetAllocMemOp) {
8949 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
8950 })
8951 .Case([&](omp::TargetFreeMemOp) {
8952 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
8953 })
8954 .Case([&](omp::AllocateDirOp) {
8955 return convertAllocateDirOp(*op, builder, moduleTranslation, *this);
8956 })
8957 .Case([&](omp::AllocateFreeOp) {
8958 return convertAllocateFreeOp(*op, builder, moduleTranslation,
8959 *this);
8960 })
8961 .Case([&](omp::AllocSharedMemOp op) {
8962 return convertAllocSharedMemOp(op, builder, moduleTranslation);
8963 })
8964 .Case([&](omp::FreeSharedMemOp op) {
8965 return convertFreeSharedMemOp(op, builder, moduleTranslation);
8966 })
8967 .Case([&](omp::GroupprivateOp) {
8968 return convertOmpGroupprivate(*op, builder, moduleTranslation);
8969 })
8970 .Default([&](Operation *inst) {
8971 return inst->emitError()
8972 << "not yet implemented: " << inst->getName();
8973 });
8974
// Balance the stackPush above. This runs whether or not the translation
// succeeded.
8975 if (isOutermostLoopWrapper)
8976 moduleTranslation.stackPop();
8977
8978 return result;
8979}
8980
// Make the OpenMP dialect available through `registry` and register an
// extension that attaches the LLVM IR translation interface to the dialect
// instance when it is loaded.
8982 registry.insert<omp::OpenMPDialect>();
8983 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
8984 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
8985 });
8986}
8987
// Build a local DialectRegistry and append it to `context`.
8989 DialectRegistry registry;
8991 context.appendDialectRegistry(registry);
8992}
for(Operation *op :ops)
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static void mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static void processIndividualMap(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag=llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, bool isTargetParam=true, int mapDataParentIdx=-1)
This function handles the insertion of a single item of map data from MapInfoData into the OMPIRBuild...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocInsertPoints(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::SmallVectorImpl< llvm::BasicBlock * > *deallocBlocks=nullptr)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo, bool first=true)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it cast to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult allocReductionVars(T op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, T op)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::omp::OMPDynGroupprivateFallbackType getDynGroupprivateFallbackType(omp::FallbackModifierAttr fallbackAttr)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult cleanupPrivateVars(T op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, PrivateVarsInfo &privateVarsInfo)
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpScope(omp::ScopeOp &scopeOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP scope construct into LLVM IR.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs, llvm::StringRef parentName="")
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP groupprivate operation into LLVM IR.
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder, llvm::vfs::FileSystem &vfs)
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult convertOmpAtomicCompare(omp::AtomicCompareOp atomicCompareOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.compare operation to LLVM IR.
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from a floating-point comparison predicate....
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static std::optional< llvm::omp::OMPAtomicCompareOp > convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate)
Helper to extract the OMPAtomicCompareOp from an integer comparison predicate. Returns std::nullopt f...
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP, llvm::ArrayRef< llvm::IRBuilderBase::InsertPoint > deallocIPs)
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::vfs::FileSystem & getFileSystem()
Returns the virtual filesystem to use for file operations.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Value getOperand(unsigned idx)
Definition Operation.h:375
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
unsigned getNumOperands()
Definition Operation.h:371
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
result_range getResults()
Definition Operation.h:440
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:227
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
Definition Utils.cpp:66
bool allocaUsesRequireSharedMem(Value alloc)
Check whether the value representing an allocation, assumed to have been defined in a shared device c...
Definition Utils.cpp:51
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.