MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
70/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
71/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
  // NOTE(review): this listing dropped original line 75 here (presumably the
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID macro, cf.
  // OpenMPLoopInfoStackFrame below); confirm against the upstream source.
  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}
  // Insertion point at which allocas for the surrounding OpenMP region should
  // be created; read back by findAllocaInsertPoint via stackWalk.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
81
82/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
83/// collapsed canonical loop information corresponding to an \c omp.loop_nest
84/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Collapsed canonical loop info for the current loop nest. Remains null
  // until set by the loop translation; stack walks that read it before then
  // observe nullptr.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
91
92/// Custom error class to signal translation errors that don't need reporting,
93/// since encountering them will have already triggered relevant error messages.
94///
95/// Its purpose is to serve as the glue between MLIR failures represented as
96/// \see LogicalResult instances and \see llvm::Error instances used to
97/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
98/// error of the first type is raised, a message is emitted directly (the \see
99/// LogicalResult itself does not hold any information). If we need to forward
100/// this error condition as an \see llvm::Error while avoiding triggering some
101/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
102/// class to just signal this situation has happened.
103///
104/// For example, this class should be used to trigger errors from within
105/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
106/// translation of their own regions. This unclutters the error log from
107/// redundant messages.
108class PreviouslyReportedError
109 : public llvm::ErrorInfo<PreviouslyReportedError> {
110public:
111 void log(raw_ostream &) const override {
112 // Do not log anything.
113 }
114
115 std::error_code convertToErrorCode() const override {
116 llvm_unreachable(
117 "PreviouslyReportedError doesn't support ECError conversion");
118 }
119
120 // Used by ErrorInfo::classID.
121 static char ID;
122};
123
124char PreviouslyReportedError::ID = 0;
125
126/*
127 * Custom class for processing linear clause for omp.wsloop
128 * and omp.simd. Linear clause translation requires setup,
129 * initialization, update, and finalization at varying
130 * basic blocks in the IR. This class helps maintain
131 * internal state to allow consistent translation in
132 * each of these stages.
133 */
134
class LinearClauseProcessor {

private:
  // Allocas holding a snapshot of each linear variable's value taken in the
  // loop preheader (the "start" value of the linear recurrence).
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas holding the per-iteration value start + iv * step for each
  // linear variable; flushed back to the originals on finalization.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // The original (caller-visible) storage of each linear variable.
  SmallVector<llvm::Value *> linearOrigVal;
  // Step value for each linear variable, parallel to the vectors above.
  SmallVector<llvm::Value *> linearSteps;
  // LLVM types of the linear variables, parallel to the vectors above.
  SmallVector<llvm::Type *> linearVarTypes;
  // Basic blocks created by splitLinearFiniBB; only valid after that call.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Register the LLVM type for a linear variable, converted from the given
  // MLIR TypeAttr.
  void registerType(LLVM::ModuleTranslation &moduleTranslation,
                    mlir::Attribute &ty) {
    linearVarTypes.push_back(moduleTranslation.convertType(
        mlir::cast<mlir::TypeAttr>(ty).getValue()));
  }

  // Allocate space for linear variables: one alloca for the preheader
  // snapshot and one for the per-iteration result. `idx` indexes into the
  // previously registered types.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::Value *linearVar, int idx) {
    linearPreconditionVars.push_back(
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
    llvm::Value *linearLoopBodyTemp =
        builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
    linearOrigVal.push_back(linearVar);
    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
  }

  // Record the LLVM value of a linear step (no IR is emitted here).
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: in the loop preheader,
  // load each original variable and snapshot it into its precondition alloca.
  void initLinearVar(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarLoad =
          builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
  }

  // Emit IR for updating linear variables: in the loop body, compute
  // start + iv * step into each loop-body temp. Integer and floating-point
  // linear variables take separate code paths.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      llvm::Type *linearVarType = linearVarTypes[index];
      llvm::Value *iv = loopInductionVar;
      llvm::Value *step = linearSteps[index];

      if (!iv->getType()->isIntegerTy())
        llvm_unreachable("OpenMP loop induction variable must be an integer "
                         "type");

      if (linearVarType->isIntegerTy()) {
        // Integer path: normalize all arithmetic to linearVarType
        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
        step = builder.CreateSExtOrTrunc(step, linearVarType);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulInst = builder.CreateMul(iv, step);
        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else if (linearVarType->isFloatingPointTy()) {
        // Float path: perform multiply in integer, then convert to float
        step = builder.CreateSExtOrTrunc(step, iv->getType());
        llvm::Value *mulInst = builder.CreateMul(iv, step);

        llvm::LoadInst *linearVarStart =
            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
      } else {
        llvm_unreachable(
            "Linear variable must be of integer or floating-point type");
      }
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same: loopExit is split into
  // finalization -> exit -> last-iter-exit blocks, rewired later by
  // finalizeLinearVar.
  void splitLinearFiniBB(llvm::IRBuilderBase &builder,
                         llvm::BasicBlock *loopExit) {
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars: copy the loop-body temps back to the original
  // variables, but only on the thread that executed the last logical
  // iteration, then emit a barrier. Requires splitLinearFiniBB to have run.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Emit stores for linear variables unconditionally (no last-iteration
  // check). Useful in case of SIMD construct.
  void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
    for (size_t index = 0; index < linearOrigVal.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVal[index]);
    }
  }

  // Rewrite all uses of the original variable, in the basic blocks whose names
  // start with `prefix`, with the linear variable in-place. Blocks are
  // discovered by a depth-first walk of successors from startBB, stopping at
  // endBB.
  void rewriteInPlace(llvm::IRBuilderBase &builder, llvm::BasicBlock *startBB,
                      llvm::BasicBlock *endBB, llvm::StringRef prefix,
                      size_t varIndex) {
    llvm::SmallVector<llvm::BasicBlock *, 32> worklist;
    llvm::SmallPtrSet<llvm::BasicBlock *, 32> visited;
    llvm::SmallPtrSet<llvm::BasicBlock *, 32> matchingBBs;

    assert(startBB && endBB && "Invalid startBB/endBB");

    // Traverse basic blocks from startBB to endBB and save those
    // whose names start with the specified prefix.
    worklist.push_back(startBB);
    visited.insert(startBB);

    while (!worklist.empty()) {
      llvm::BasicBlock *bb = worklist.pop_back_val();

      if (bb->hasName() && bb->getName().starts_with(prefix))
        matchingBBs.insert(bb);

      if (bb == endBB)
        continue;

      for (llvm::BasicBlock *succ : llvm::successors(bb)) {
        if (visited.insert(succ).second)
          worklist.push_back(succ);
      }
    }

    // Rewrite all uses in the matching BBs.
    for (auto *user : linearOrigVal[varIndex]->users()) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        if (matchingBBs.contains(userInst->getParent()))
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
321
322} // namespace
323
324/// Looks up from the operation from and returns the PrivateClauseOp with
325/// name symbolName
static omp::PrivateClauseOp findPrivatizer(Operation *from,
                                           SymbolRefAttr symbolName) {
  omp::PrivateClauseOp privatizer =
      // NOTE(review): this listing dropped original line 329 here (presumably
      // a SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from, ...)
      // call); restore from the upstream source before building.
      symbolName);
  // A missing privatizer is a compiler invariant violation, not a user error.
  assert(privatizer && "privatizer not found in the symbol table");
  return privatizer;
}
334
335/// Check whether translation to LLVM IR for the given operation is currently
336/// supported. If not, descriptive diagnostics will be emitted to let users know
337/// this is a not-yet-implemented feature.
338///
339/// \returns success if no unimplemented features are needed to translate the
340/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Shared diagnostic emitter for all unimplemented-clause checks below.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // One checker per clause; each flags `result` when the clause is present on
  // an operation for which its translation is not implemented.
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDependIteratorModifier = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependIterated().empty() ||
        (op.getDependIteratedKinds() && !op.getDependIteratedKinds()->empty()))
      result = todo("depend with iterator modifier");
  };
  // Note: a discarded hint is only a warning, never a failure.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if (!op.getPrivateVars().empty() || op.getPrivateSyms())
      result = todo("privatization");
  };
  // Plain reduction is rejected only on teams/taskloop; a non-default
  // reduction modifier is rejected everywhere this checker is used.
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumTeamsMultiDim())
      result = todo("num_teams with multi-dimensional values");
  };
  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
    if (op.hasNumThreadsMultiDim())
      result = todo("num_threads with multi-dimensional values");
  };

  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
    if (op.hasThreadLimitMultiDim())
      result = todo("thread_limit with multi-dimensional values");
  };

  LogicalResult result = success();
  // NOTE(review): this listing dropped original line 415 here (the
  // llvm::TypeSwitch<Operation &>(op) expression that the .Case chain below
  // attaches to); restore from the upstream source.
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkNumTeams(op, result);
        checkThreadLimit(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkDependIteratorModifier(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
        checkNumThreads(op, result);
      })
      .Case([&](omp::SimdOp op) { checkReduction(op, result); })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDependIteratorModifier(op, result);
        checkInReduction(op, result);
        checkThreadLimit(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
482
483static LogicalResult handleError(llvm::Error error, Operation &op) {
484 LogicalResult result = success();
485 if (error) {
486 llvm::handleAllErrors(
487 std::move(error),
488 [&](const PreviouslyReportedError &) { result = failure(); },
489 [&](const llvm::ErrorInfoBase &err) {
490 result = op.emitError(err.message());
491 });
492 }
493 return result;
494}
495
496template <typename T>
497static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
498 if (!result)
499 return handleError(result.takeError(), op);
500
501 return success();
502}
503
504/// Find the insertion point for allocas given the current insertion point for
505/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  // In cases with multiple levels of outlining, the tree walk might find an
  // alloca insertion point that is inside the original function while the
  // builder insertion point is inside the outlined function. We need to make
  // sure that we do not use it in those cases.
  if (walkResult.wasInterrupted() &&
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocs.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InsertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  // Allocas go at the start of the function's (possibly freshly vacated)
  // entry block.
  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
550
551/// Find the loop information structure for the loop nest being translated. It
552/// will return a `null` value unless called from the translation function for
553/// a loop wrapper operation after successfully translating its body.
554static llvm::CanonicalLoopInfo *
556 llvm::CanonicalLoopInfo *loopInfo = nullptr;
557 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
558 [&](OpenMPLoopInfoStackFrame &frame) {
559 loopInfo = frame.loopInfo;
560 return WalkResult::interrupt();
561 });
562 return loopInfo;
563}
564
565/// Converts the given region that appears within an OpenMP dialect operation to
566/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
567/// region, and a branch from any block with an successor-less OpenMP terminator
568/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
569/// of the continuation block if provided.
// NOTE(review): this listing dropped original line 570 here (the function's
// return type and name, presumably
// `static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(`); restore
// from the upstream source.
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  // Split off everything after the current insertion point into the
  // continuation block; the region's blocks are stitched in between.
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Pre-create one LLVM block per MLIR block so branches can be resolved
  // regardless of conversion order.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // First yield seen: record the operand types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Subsequent yields must agree in arity and types.
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  // NOTE(review): this listing dropped original line 639 here (the definition
  // of `blocks`, presumably a dominance/topologically sorted collection of the
  // region's blocks); restore from the upstream source.
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
692
693/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
694static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
695 switch (kind) {
696 case omp::ClauseProcBindKind::Close:
697 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
698 case omp::ClauseProcBindKind::Master:
699 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
700 case omp::ClauseProcBindKind::Primary:
701 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
702 case omp::ClauseProcBindKind::Spread:
703 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
704 }
705 llvm_unreachable("Unknown ClauseProcBindKind kind");
706}
707
708/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
709static LogicalResult
710convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
711 LLVM::ModuleTranslation &moduleTranslation) {
712 auto maskedOp = cast<omp::MaskedOp>(opInst);
713 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
714
715 if (failed(checkImplementationStatus(opInst)))
716 return failure();
717
718 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
719 // MaskedOp has only one region associated with it.
720 auto &region = maskedOp.getRegion();
721 builder.restoreIP(codeGenIP);
722 return convertOmpOpRegions(region, "omp.masked.region", builder,
723 moduleTranslation)
724 .takeError();
725 };
726
727 // TODO: Perform finalization actions for variables. This has to be
728 // called for variables which have destructors/finalizers.
729 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
730
731 llvm::Value *filterVal = nullptr;
732 if (auto filterVar = maskedOp.getFilteredThreadId()) {
733 filterVal = moduleTranslation.lookupValue(filterVar);
734 } else {
735 llvm::LLVMContext &llvmContext = builder.getContext();
736 filterVal =
737 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
738 }
739 assert(filterVal != nullptr);
740 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
741 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
742 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
743 finiCB, filterVal);
744
745 if (failed(handleError(afterIP, opInst)))
746 return failure();
747
748 builder.restoreIP(*afterIP);
749 return success();
750}
751
752/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
753static LogicalResult
754convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
755 LLVM::ModuleTranslation &moduleTranslation) {
756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
757 auto masterOp = cast<omp::MasterOp>(opInst);
758
759 if (failed(checkImplementationStatus(opInst)))
760 return failure();
761
762 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
763 // MasterOp has only one region associated with it.
764 auto &region = masterOp.getRegion();
765 builder.restoreIP(codeGenIP);
766 return convertOmpOpRegions(region, "omp.master.region", builder,
767 moduleTranslation)
768 .takeError();
769 };
770
771 // TODO: Perform finalization actions for variables. This has to be
772 // called for variables which have destructors/finalizers.
773 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
774
775 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
776 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
777 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
778 finiCB);
779
780 if (failed(handleError(afterIP, opInst)))
781 return failure();
782
783 builder.restoreIP(*afterIP);
784 return success();
785}
786
787/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too.
  if (criticalOp.getNameAttr()) {
    // The verifiers in OpenMP Dialect guarantee that all the pointers are
    // non-null
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    auto criticalDeclareOp =
        // NOTE(review): this listing dropped original line 820 here
        // (presumably a SymbolTable::lookupNearestSymbolFrom<
        // omp::CriticalDeclareOp>(criticalOp, ...) call); restore from the
        // upstream source.
        symbolRef);
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
836
/// A util to collect info needed to convert delayed privatizers from MLIR to
/// LLVM.
  // Gathers the private-clause entry block arguments of `op` and records, in
  // parallel vectors, the MLIR values being privatized and the privatizer
  // declarations that apply to them.
  template <typename OP>
      : blockArgs(
            cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
    // One entry per private block argument; reserve up front to avoid
    // reallocation while collecting.
    mlirVars.reserve(blockArgs.size());
    llvmVars.reserve(blockArgs.size());
    collectPrivatizationDecls<OP>(op);

    for (mlir::Value privateVar : op.getPrivateVars())
      mlirVars.push_back(privateVar);
  }


private:
  /// Populates `privatizations` with privatization declarations used for the
  /// given op.
  template <class OP>
  void collectPrivatizationDecls(OP op) {
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    // No private symbols attached means nothing to privatize.
    if (!attr)
      return;

    privatizers.reserve(privatizers.size() + attr->size());
    // Resolve each privatizer symbol to its declaration op.
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
      privatizers.push_back(findPrivatizer(op, symbolRef));
    }
  }
};
872
/// Populates `reductions` with reduction declarations used in the given op.
/// Each symbol in the op's `reduction_syms` attribute is resolved and its
/// declaration appended, in order, to `reductions`.
template <typename T>
static void
  std::optional<ArrayAttr> attr = op.getReductionSyms();
  // No reduction symbols attached means there is nothing to collect.
  if (!attr)
    return;

  // Grow once for all of this op's reduction variables.
  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    reductions.push_back(
        op, symbolRef));
  }
}
889
/// Translates the blocks contained in the given region and appends them to at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  // An empty region yields nothing and needs no translation.
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (region.hasOneBlock()) {
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    // Temporarily detach the insertion block's terminator so that converted
    // operations are appended before it; it is re-inserted below.
    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    // Put the detached terminator back at the end of the (possibly grown)
    // insertion block.
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    }

    return success();
  }

  // Multi-block case: convert the region into new blocks and continue at the
  // continuation block, collecting yielded values as PHIs.
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);

  if (failed(handleError(continuationBlock, *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
  return success();
}
957
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
// Non-atomic combiner: (insert point, lhs, rhs, out result) -> insert point.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
// Atomic combiner: (insert point, element type, lhs, rhs) -> insert point.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
// By-ref data pointer accessor: (insert point, by-ref value, out result)
// -> insert point.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
973
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
/// reduction declaration. The generator uses `builder` but ignores its
/// insertion point.
static OwningReductionGen
makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the combiner region's block arguments to the incoming operands
    // before inlining it at the requested insertion point.
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    // The combiner yields exactly one value: the combined result.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };
  return gen;
}
1002
/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
/// given reduction declaration. The generator uses `builder` but ignores its
/// insertion point. Returns null if there is no atomic region available in the
/// reduction declaration.
static OwningAtomicReductionGen
makeAtomicReductionGen(omp::DeclareReductionOp decl,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  // No atomic region: signal "not available" with a null std::function.
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningAtomicReductionGen atomicGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    // Unlike the non-atomic combiner, the atomic region yields no values.
    assert(phis.empty());
    return builder.saveIP();
  };
  return atomicGen;
}
1035
/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
/// the given reduction declaration. The generator uses `builder` but ignores
/// its insertion point. Returns null if there is no `data_ptr_ptr` region
/// available in the reduction declaration.
static OwningDataPtrPtrReductionGen
makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
  // Only by-ref reductions use this generator; return a null std::function
  // otherwise.
  if (!isByRef)
    return OwningDataPtrPtrReductionGen();

  // Capture `decl` by value so the lambda stays valid after this function
  // returns; mutable to call its non-const accessors.
  OwningDataPtrPtrReductionGen refDataPtrGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *byRefVal, llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
                                       "omp.data_ptr_ptr.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
    // The region yields a single value: the pointer to the reduction data.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };

  return refDataPtrGen;
}
1064
1065/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1066static LogicalResult
1067convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1068 LLVM::ModuleTranslation &moduleTranslation) {
1069 auto orderedOp = cast<omp::OrderedOp>(opInst);
1070
1071 if (failed(checkImplementationStatus(opInst)))
1072 return failure();
1073
1074 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1075 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1076 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1077 SmallVector<llvm::Value *> vecValues =
1078 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1079
1080 size_t indexVecValues = 0;
1081 while (indexVecValues < vecValues.size()) {
1082 SmallVector<llvm::Value *> storeValues;
1083 storeValues.reserve(numLoops);
1084 for (unsigned i = 0; i < numLoops; i++) {
1085 storeValues.push_back(vecValues[indexVecValues]);
1086 indexVecValues++;
1087 }
1088 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1089 findAllocaInsertPoint(builder, moduleTranslation);
1090 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1091 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1092 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1093 }
1094 return success();
1095}
1096
1097/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1098/// OpenMPIRBuilder.
1099static LogicalResult
1100convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1101 LLVM::ModuleTranslation &moduleTranslation) {
1102 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1103 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1104
1105 if (failed(checkImplementationStatus(opInst)))
1106 return failure();
1107
1108 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1109 // OrderedOp has only one region associated with it.
1110 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1111 builder.restoreIP(codeGenIP);
1112 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1113 moduleTranslation)
1114 .takeError();
1115 };
1116
1117 // TODO: Perform finalization actions for variables. This has to be
1118 // called for variables which have destructors/finalizers.
1119 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1120
1121 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1122 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1123 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1124 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1125
1126 if (failed(handleError(afterIP, opInst)))
1127 return failure();
1128
1129 builder.restoreIP(*afterIP);
1130 return success();
1131}
1132
1133namespace {
1134/// Contains the arguments for an LLVM store operation
1135struct DeferredStore {
1136 DeferredStore(llvm::Value *value, llvm::Value *address)
1137 : value(value), address(address) {}
1138
1139 llvm::Value *value;
1140 llvm::Value *address;
1141};
1142} // namespace
1143
/// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas
template <typename T>
static LogicalResult
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // Emit all allocas at the designated alloca insertion point; the guard
  // restores the builder's previous position on exit.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // delay creating stores until after all allocas
  deferredStores.reserve(loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      if (allocRegion.empty())
        continue;

      // Inline the declaration's `alloc` region to produce the real storage.
      if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      // Normalize both pointers to the default pointer type before linking
      // them with a (deferred) store.
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);

      deferredStores.emplace_back(castPhi, castVar);

      privateReductionVariables[i] = castVar;
      moduleTranslation.mapValue(reductionArgs[i], castPhi);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");
      // By-value: allocate the reduction variable directly.
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);

      moduleTranslation.mapValue(reductionArgs[i], castVar);
      privateReductionVariables[i] = castVar;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
    }
  }

  return success();
}
1213
/// Map input arguments to reduction initialization region
/// Binds the `omp.declare_reduction` initializer region's block arguments
/// (the "mold" and, if present, the allocation) for reduction variable `i`.
template <typename T>
static void
                      llvm::IRBuilderBase &builder,
                      DenseMap<Value, llvm::Value *> &reductionVariableMap,
                      unsigned i) {
  // map input argument to the initialization region
  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
  Region &initializerRegion = reduction.getInitializerRegion();
  Block &entry = initializerRegion.front();

  mlir::Value mlirSource = loop.getReductionVars()[i];
  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
  llvm::Value *origVal = llvmSource;
  // If a non-pointer value is expected, load the value from the source pointer.
  if (!isa<LLVM::LLVMPointerType>(
          reduction.getInitializerMoldArg().getType()) &&
      isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
    origVal =
        builder.CreateLoad(moduleTranslation.convertType(
                               reduction.getInitializerMoldArg().getType()),
                           llvmSource, "omp_orig");
  }
  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);

  // A second block argument, when present, receives the pre-made allocation
  // recorded by allocReductionVars.
  if (entry.getNumArguments() > 1) {
    llvm::Value *allocation =
        reductionVariableMap.lookup(loop.getReductionVars()[i]);
    moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
  }
}
1247
1248static void
1249setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1250 llvm::BasicBlock *block = nullptr) {
1251 if (block == nullptr)
1252 block = builder.GetInsertBlock();
1253
1254 if (!block->hasTerminator())
1255 builder.SetInsertPoint(block);
1256 else
1257 builder.SetInsertPoint(block->getTerminator());
1258}
1259
/// Inline reductions' `init` regions. This functions assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder's insertions point in a state
/// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  // Nothing to initialize when the op carries no reductions.
  if (op.getNumReductionVars() == 0)
    return success();

  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  // Allocas for legacy by-ref reductions (no alloc region) go into the latest
  // alloca block, before its terminator.
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduciton variable allocated in the inlined region)
      byRefVars[i] = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
    }
  }

  setInsertPointForPossiblyEmptyBlock(builder, initBlock);

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
                          reductionVariableMap, i);

    // TODO In some cases (specially on the GPU), the init regions may
    // contains stack alloctaions. If the region is inlined in a loop, this is
    // problematic. Instead of just inlining the region, handle allocations by
    // hoisting fixed length allocations to the function entry and using
    // stacksave and restore for variable length ones.
    if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");


    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(phis[0], privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1359
/// Collect reduction info
/// Builds the owning generator callables for each reduction declaration and
/// assembles the OpenMPIRBuilder::ReductionInfo list consumed by
/// createReductions.
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    ArrayRef<bool> isByRef) {
  unsigned numReductions = loop.getNumReductionVars();

  // First materialize the owning generators; ReductionInfo only stores
  // references to them, so they must outlive this collection.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    // A null owning generator means the declaration has no atomic region.
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    // Recover the element type from the first alloca found in the alloc
    // region, if any.
    mlir::Type allocatedType;
    reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
      if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
        allocatedType = alloca.getElemType();
      }

    });

    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         /*ReductionGenClang=*/nullptr, atomicGen,
         allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
         reductionDecls[i].getByrefElementType()
             ? moduleTranslation.convertType(
                   *reductionDecls[i].getByrefElementType())
             : nullptr});
  }
}
1415
/// handling of DeclareReductionOp's cleanup region
/// Inlines, for every non-empty cleanup region, the region body with its
/// single block argument bound to the corresponding private variable (loaded
/// first when `shouldLoadCleanupRegionArg` is set).
static LogicalResult
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
    // Declarations without a cleanup region need no finalization.
    if (cleanupRegion->empty())
      continue;

    // map the argument to the cleanup region
    Block &entry = cleanupRegion->front();

    // If the insertion block already ends in a terminator, emit the cleanup
    // code before it.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  moduleTranslation.convertType(entry.getArgument(0).getType()),
                  privateVariables[i])
            : privateVariables[i];

    moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);

    if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
                                       moduleTranslation)))
      return failure();

    // clear block argument mapping in case it needs to be re-created with a
    // different source for another use of the same reduction decl
    moduleTranslation.forgetMapping(*cleanupRegion);
  }
  return success();
}
1454
1455// TODO: not used by ParallelOp
1456template <class OP>
1457static LogicalResult createReductionsAndCleanup(
1458 OP op, llvm::IRBuilderBase &builder,
1459 LLVM::ModuleTranslation &moduleTranslation,
1460 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1462 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1463 bool isNowait = false, bool isTeamsReduction = false) {
1464 // Process the reductions if required.
1465 if (op.getNumReductionVars() == 0)
1466 return success();
1467
1469 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1470 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1472
1473 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1474
1475 // Create the reduction generators. We need to own them here because
1476 // ReductionInfo only accepts references to the generators.
1477 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1478 owningReductionGens, owningAtomicReductionGens,
1479 owningReductionGenRefDataPtrGens,
1480 privateReductionVariables, reductionInfos, isByRef);
1481
1482 // The call to createReductions below expects the block to have a
1483 // terminator. Create an unreachable instruction to serve as terminator
1484 // and remove it later.
1485 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1486 builder.SetInsertPoint(tempTerminator);
1487 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1488 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1489 isByRef, isNowait, isTeamsReduction);
1490
1491 if (failed(handleError(contInsertPoint, *op)))
1492 return failure();
1493
1494 if (!contInsertPoint->getBlock())
1495 return op->emitOpError() << "failed to convert reductions";
1496
1497 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1498 if (!isTeamsReduction) {
1499 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1500 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1501
1502 if (failed(handleError(barrierIP, *op)))
1503 return failure();
1504 afterIP = *barrierIP;
1505 }
1506
1507 tempTerminator->eraseFromParent();
1508 builder.restoreIP(afterIP);
1509
1510 // after the construct, deallocate private reduction variables
1511 SmallVector<Region *> reductionRegions;
1512 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1513 [](omp::DeclareReductionOp reductionDecl) {
1514 return &reductionDecl.getCleanupRegion();
1515 });
1516 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1517 moduleTranslation, builder,
1518 "omp.reduction.cleanup");
1519 return success();
1520}
1521
1522static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1523 if (!attr)
1524 return {};
1525 return *attr;
1526}
1527
// TODO: not used by omp.parallel
/// Allocates storage for the op's reduction variables and then inlines their
/// initializer regions; a no-op when the op carries no reductions.
template <typename OP>
    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
    DenseMap<Value, llvm::Value *> &reductionVariableMap,
    llvm::ArrayRef<bool> isByRef) {
  if (op.getNumReductionVars() == 0)
    return success();

  // Stores produced during allocation are deferred until after all allocas;
  // initReductionVars emits them.
  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
                                allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  return initReductionVars(op, reductionArgs, builder, moduleTranslation,
                           allocaIP.getBlock(), reductionDecls,
                           privateReductionVariables, reductionVariableMap,
                           isByRef, deferredStores);
}
1554
1555/// Return the llvm::Value * corresponding to the `privateVar` that
1556/// is being privatized. It isn't always as simple as looking up
1557/// moduleTranslation with privateVar. For instance, in case of
1558/// an allocatable, the descriptor for the allocatable is privatized.
1559/// This descriptor is mapped using an MapInfoOp. So, this function
1560/// will return a pointer to the llvm::Value corresponding to the
1561/// block argument for the mapped descriptor.
1562static llvm::Value *
1563findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1564 LLVM::ModuleTranslation &moduleTranslation,
1565 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1566 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1567 return moduleTranslation.lookupValue(privateVar);
1568
1569 Value blockArg = (*mappedPrivateVars)[privateVar];
1570 Type privVarType = privateVar.getType();
1571 Type blockArgType = blockArg.getType();
1572 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1573 "A block argument corresponding to a mapped var should have "
1574 "!llvm.ptr type");
1575
1576 if (privVarType == blockArgType)
1577 return moduleTranslation.lookupValue(blockArg);
1578
1579 // This typically happens when the privatized type is lowered from
1580 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1581 // struct/pair is passed by value. But, mapped values are passed only as
1582 // pointers, so before we privatize, we must load the pointer.
1583 if (!isa<LLVM::LLVMPointerType>(privVarType))
1584 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1585 moduleTranslation.lookupValue(blockArg));
1586
1587 return moduleTranslation.lookupValue(privateVar);
1588}
1589
/// Initialize a single (first)private variable. You probably want to use
/// allocateAndInitPrivateVars instead of this.
/// This returns the private variable which has been initialized. This
/// variable should be mapped before constructing the body of the Op.
initPrivateVar(llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation,
               omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
               BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
               llvm::BasicBlock *privInitBlock,
               llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  Region &initRegion = privDecl.getInitRegion();
  // Without an init region, the preallocated private variable is used as-is.
  if (initRegion.empty())
    return llvmPrivateVar;

  assert(nonPrivateVar);
  // Bind the init region's "mold" (original value) and "private" block
  // arguments before inlining it.
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);

  // in-place convert the private initialization region
  if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
                                     moduleTranslation, &phis)))
    return llvm::createStringError(
        "failed to inline `init` region of `omp.private`");

  assert(phis.size() == 1 && "expected one allocation to be yielded");

  // clear init region block argument mapping in case it needs to be
  // re-created with a different source for another use of the same
  // reduction decl
  moduleTranslation.forgetMapping(initRegion);

  // Prefer the value yielded from the init region to the allocated private
  // variable in case the region is operating on arguments by-value (e.g.
  // Fortran character boxes).
  return phis[0];
}
1628
1629/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
1631 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1632 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1633 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1634 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1635 return initPrivateVar(
1636 builder, moduleTranslation, privDecl,
1637 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1638 mappedPrivateVars),
1639 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1640}
1641
1642static llvm::Error
1643initPrivateVars(llvm::IRBuilderBase &builder,
1644 LLVM::ModuleTranslation &moduleTranslation,
1645 PrivateVarsInfo &privateVarsInfo,
1646 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1647 if (privateVarsInfo.blockArgs.empty())
1648 return llvm::Error::success();
1649
1650 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1651 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1652
1653 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1654 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1655 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1656 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1658 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1659 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1660
1661 if (!privVarOrErr)
1662 return privVarOrErr.takeError();
1663
1664 llvmPrivateVar = privVarOrErr.get();
1665 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1666
1668 }
1669
1670 return llvm::Error::success();
1671}
1672
1673/// Allocate and initialize delayed private variables. Returns the basic block
1674/// which comes after all of these allocations. llvm::Value * for each of these
1675/// private variables are populated in llvmPrivateVars.
1677allocatePrivateVars(llvm::IRBuilderBase &builder,
1678 LLVM::ModuleTranslation &moduleTranslation,
1679 PrivateVarsInfo &privateVarsInfo,
1680 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1681 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1682 // Allocate private vars
1683 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1684 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1685 allocaTerminator->getIterator()),
1686 true, allocaTerminator->getStableDebugLoc(),
1687 "omp.region.after_alloca");
1688
1689 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1690 // Update the allocaTerminator since the alloca block was split above.
1691 allocaTerminator = allocaIP.getBlock()->getTerminator();
1692 builder.SetInsertPoint(allocaTerminator);
1693 // The new terminator is an uncondition branch created by the splitBB above.
1694 assert(allocaTerminator->getNumSuccessors() == 1 &&
1695 "This is an unconditional branch created by splitBB");
1696
1697 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1698 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1699
1700 unsigned int allocaAS =
1701 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1702 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1703 ->getDataLayout()
1704 .getProgramAddressSpace();
1705
1706 for (auto [privDecl, mlirPrivVar, blockArg] :
1707 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1708 privateVarsInfo.blockArgs)) {
1709 llvm::Type *llvmAllocType =
1710 moduleTranslation.convertType(privDecl.getType());
1711 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1712 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1713 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1714 if (allocaAS != defaultAS)
1715 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1716 builder.getPtrTy(defaultAS));
1717
1718 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1719 }
1720
1721 return afterAllocas;
1722}
1723
1724/// This can't always be determined statically, but when we can, it is good to
1725/// avoid generating compiler-added barriers which will deadlock the program.
1727 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1728 parent = parent->getParentOp()) {
1729 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1730 return true;
1731
1732 // e.g.
1733 // omp.single {
1734 // omp.parallel {
1735 // op
1736 // }
1737 // }
1738 if (mlir::isa<omp::ParallelOp>(parent))
1739 return false;
1740 }
1741 return false;
1742}
1743
1744static LogicalResult copyFirstPrivateVars(
1745 mlir::Operation *op, llvm::IRBuilderBase &builder,
1746 LLVM::ModuleTranslation &moduleTranslation,
1748 ArrayRef<llvm::Value *> llvmPrivateVars,
1749 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1750 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1751 // Apply copy region for firstprivate.
1752 bool needsFirstprivate =
1753 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1754 return privOp.getDataSharingType() ==
1755 omp::DataSharingClauseType::FirstPrivate;
1756 });
1757
1758 if (!needsFirstprivate)
1759 return success();
1760
1761 llvm::BasicBlock *copyBlock =
1762 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1763 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1764
1765 for (auto [decl, moldVar, llvmVar] :
1766 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1767 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1768 continue;
1769
1770 // copyRegion implements `lhs = rhs`
1771 Region &copyRegion = decl.getCopyRegion();
1772
1773 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1774
1775 // map copyRegion lhs arg
1776 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1777
1778 // in-place convert copy region
1779 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1780 moduleTranslation)))
1781 return decl.emitError("failed to inline `copy` region of `omp.private`");
1782
1784
1785 // ignore unused value yielded from copy region
1786
1787 // clear copy region block argument mapping in case it needs to be
1788 // re-created with different sources for reuse of the same reduction
1789 // decl
1790 moduleTranslation.forgetMapping(copyRegion);
1791 }
1792
1793 if (insertBarrier && !opIsInSingleThread(op)) {
1794 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1795 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1796 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1797 if (failed(handleError(res, *op)))
1798 return failure();
1799 }
1800
1801 return success();
1802}
1803
1804static LogicalResult copyFirstPrivateVars(
1805 mlir::Operation *op, llvm::IRBuilderBase &builder,
1806 LLVM::ModuleTranslation &moduleTranslation,
1807 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1808 ArrayRef<llvm::Value *> llvmPrivateVars,
1809 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1810 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1811 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1812 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1813 // map copyRegion rhs arg
1814 llvm::Value *moldVar = findAssociatedValue(
1815 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1816 assert(moldVar);
1817 return moldVar;
1818 });
1819 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1820 llvmPrivateVars, privateDecls, insertBarrier,
1821 mappedPrivateVars);
1822}
1823
1824static LogicalResult
1825cleanupPrivateVars(llvm::IRBuilderBase &builder,
1826 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1827 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1829 // private variable deallocation
1830 SmallVector<Region *> privateCleanupRegions;
1831 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1832 [](omp::PrivateClauseOp privatizer) {
1833 return &privatizer.getDeallocRegion();
1834 });
1835
1836 if (failed(inlineOmpRegionCleanup(
1837 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1838 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1839 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1840 "`omp.private` op in");
1841
1842 return success();
1843}
1844
1845/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1847 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1848 // be visible and not inside of function calls. This is enforced by the
1849 // verifier.
1850 return op
1851 ->walk([](Operation *child) {
1852 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1853 return WalkResult::interrupt();
1854 return WalkResult::advance();
1855 })
1856 .wasInterrupted();
1857}
1858
1859static LogicalResult
1860convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1861 LLVM::ModuleTranslation &moduleTranslation) {
1862 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1863 using StorableBodyGenCallbackTy =
1864 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1865
1866 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1867
1868 if (failed(checkImplementationStatus(opInst)))
1869 return failure();
1870
1871 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1872 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1873
1875 collectReductionDecls(sectionsOp, reductionDecls);
1876 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1877 findAllocaInsertPoint(builder, moduleTranslation);
1878
1879 SmallVector<llvm::Value *> privateReductionVariables(
1880 sectionsOp.getNumReductionVars());
1881 DenseMap<Value, llvm::Value *> reductionVariableMap;
1882
1883 MutableArrayRef<BlockArgument> reductionArgs =
1884 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1885
1887 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1888 reductionDecls, privateReductionVariables, reductionVariableMap,
1889 isByRef)))
1890 return failure();
1891
1893
1894 for (Operation &op : *sectionsOp.getRegion().begin()) {
1895 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1896 if (!sectionOp) // omp.terminator
1897 continue;
1898
1899 Region &region = sectionOp.getRegion();
1900 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1901 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1902 builder.restoreIP(codeGenIP);
1903
1904 // map the omp.section reduction block argument to the omp.sections block
1905 // arguments
1906 // TODO: this assumes that the only block arguments are reduction
1907 // variables
1908 assert(region.getNumArguments() ==
1909 sectionsOp.getRegion().getNumArguments());
1910 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1911 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1912 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1913 assert(llvmVal);
1914 moduleTranslation.mapValue(sectionArg, llvmVal);
1915 }
1916
1917 return convertOmpOpRegions(region, "omp.section.region", builder,
1918 moduleTranslation)
1919 .takeError();
1920 };
1921 sectionCBs.push_back(sectionCB);
1922 }
1923
1924 // No sections within omp.sections operation - skip generation. This situation
1925 // is only possible if there is only a terminator operation inside the
1926 // sections operation
1927 if (sectionCBs.empty())
1928 return success();
1929
1930 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1931
1932 // TODO: Perform appropriate actions according to the data-sharing
1933 // attribute (shared, private, firstprivate, ...) of variables.
1934 // Currently defaults to shared.
1935 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1936 llvm::Value &vPtr, llvm::Value *&replacementValue)
1937 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1938 replacementValue = &vPtr;
1939 return codeGenIP;
1940 };
1941
1942 // TODO: Perform finalization actions for variables. This has to be
1943 // called for variables which have destructors/finalizers.
1944 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1945
1946 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1947 bool isCancellable = constructIsCancellable(sectionsOp);
1948 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1949 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1950 moduleTranslation.getOpenMPBuilder()->createSections(
1951 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1952 sectionsOp.getNowait());
1953
1954 if (failed(handleError(afterIP, opInst)))
1955 return failure();
1956
1957 builder.restoreIP(*afterIP);
1958
1959 // Process the reductions if required.
1961 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1962 privateReductionVariables, isByRef, sectionsOp.getNowait());
1963}
1964
1965/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1966static LogicalResult
1967convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1968 LLVM::ModuleTranslation &moduleTranslation) {
1969 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1970 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1971
1972 if (failed(checkImplementationStatus(*singleOp)))
1973 return failure();
1974
1975 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1976 builder.restoreIP(codegenIP);
1977 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1978 builder, moduleTranslation)
1979 .takeError();
1980 };
1981 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1982
1983 // Handle copyprivate
1984 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1985 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1988 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1989 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1991 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1992 llvmCPFuncs.push_back(
1993 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1994 }
1995
1996 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1997 moduleTranslation.getOpenMPBuilder()->createSingle(
1998 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1999 llvmCPFuncs);
2000
2001 if (failed(handleError(afterIP, *singleOp)))
2002 return failure();
2003
2004 builder.restoreIP(*afterIP);
2005 return success();
2006}
2007
2008static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
2009 auto iface =
2010 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
2011 // Check that all uses of the reduction block arg has the same distribute op
2012 // parent.
2014 Operation *distOp = nullptr;
2015 for (auto ra : iface.getReductionBlockArgs())
2016 for (auto &use : ra.getUses()) {
2017 auto *useOp = use.getOwner();
2018 // Ignore debug uses.
2019 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2020 debugUses.push_back(useOp);
2021 continue;
2022 }
2023
2024 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2025 // Use is not inside a distribute op - return false
2026 if (!currentDistOp)
2027 return false;
2028 // Multiple distribute operations - return false
2029 Operation *currentOp = currentDistOp.getOperation();
2030 if (distOp && (distOp != currentOp))
2031 return false;
2032
2033 distOp = currentOp;
2034 }
2035
2036 // If we are going to use distribute reduction then remove any debug uses of
2037 // the reduction parameters in teamsOp. Otherwise they will be left without
2038 // any mapped value in moduleTranslation and will eventually error out.
2039 for (auto *use : debugUses)
2040 use->erase();
2041 return true;
2042}
2043
2044// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
2045static LogicalResult
2046convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2047 LLVM::ModuleTranslation &moduleTranslation) {
2048 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2049 if (failed(checkImplementationStatus(*op)))
2050 return failure();
2051
2052 DenseMap<Value, llvm::Value *> reductionVariableMap;
2053 unsigned numReductionVars = op.getNumReductionVars();
2055 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
2056 llvm::ArrayRef<bool> isByRef;
2057 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2058 findAllocaInsertPoint(builder, moduleTranslation);
2059
2060 // Only do teams reduction if there is no distribute op that captures the
2061 // reduction instead.
2062 bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
2063 if (doTeamsReduction) {
2064 isByRef = getIsByRef(op.getReductionByref());
2065
2066 assert(isByRef.size() == op.getNumReductionVars());
2067
2068 MutableArrayRef<BlockArgument> reductionArgs =
2069 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2070
2071 collectReductionDecls(op, reductionDecls);
2072
2074 op, reductionArgs, builder, moduleTranslation, allocaIP,
2075 reductionDecls, privateReductionVariables, reductionVariableMap,
2076 isByRef)))
2077 return failure();
2078 }
2079
2080 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2082 moduleTranslation, allocaIP);
2083 builder.restoreIP(codegenIP);
2084 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2085 moduleTranslation)
2086 .takeError();
2087 };
2088
2089 llvm::Value *numTeamsLower = nullptr;
2090 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2091 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2092
2093 llvm::Value *numTeamsUpper = nullptr;
2094 if (!op.getNumTeamsUpperVars().empty())
2095 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2096
2097 llvm::Value *threadLimit = nullptr;
2098 if (!op.getThreadLimitVars().empty())
2099 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2100
2101 llvm::Value *ifExpr = nullptr;
2102 if (Value ifVar = op.getIfExpr())
2103 ifExpr = moduleTranslation.lookupValue(ifVar);
2104
2105 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2106 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2107 moduleTranslation.getOpenMPBuilder()->createTeams(
2108 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2109
2110 if (failed(handleError(afterIP, *op)))
2111 return failure();
2112
2113 builder.restoreIP(*afterIP);
2114 if (doTeamsReduction) {
2115 // Process the reductions if required.
2117 op, builder, moduleTranslation, allocaIP, reductionDecls,
2118 privateReductionVariables, isByRef,
2119 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2120 }
2121 return success();
2122}
2123
2124static void
2125buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2126 LLVM::ModuleTranslation &moduleTranslation,
2128 if (dependVars.empty())
2129 return;
2130 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2131 llvm::omp::RTLDependenceKindTy type;
2132 switch (
2133 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2134 case mlir::omp::ClauseTaskDepend::taskdependin:
2135 type = llvm::omp::RTLDependenceKindTy::DepIn;
2136 break;
2137 // The OpenMP runtime requires that the codegen for 'depend' clause for
2138 // 'out' dependency kind must be the same as codegen for 'depend' clause
2139 // with 'inout' dependency.
2140 case mlir::omp::ClauseTaskDepend::taskdependout:
2141 case mlir::omp::ClauseTaskDepend::taskdependinout:
2142 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2143 break;
2144 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2145 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2146 break;
2147 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2148 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2149 break;
2150 };
2151 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2152 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2153 dds.emplace_back(dd);
2154 }
2155}
2156
2157/// Shared implementation of a callback which adds a termiator for the new block
2158/// created for the branch taken when an openmp construct is cancelled. The
2159/// terminator is saved in \p cancelTerminators. This callback is invoked only
2160/// if there is cancellation inside of the taskgroup body.
2161/// The terminator will need to be fixed to branch to the correct block to
2162/// cleanup the construct.
2164 SmallVectorImpl<llvm::UncondBrInst *> &cancelTerminators,
2165 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2166 mlir::Operation *op, llvm::omp::Directive cancelDirective) {
2167 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2168 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2169
2170 // ip is currently in the block branched to if cancellation occurred.
2171 // We need to create a branch to terminate that block.
2172 llvmBuilder.restoreIP(ip);
2173
2174 // We must still clean up the construct after cancelling it, so we need to
2175 // branch to the block that finalizes the taskgroup.
2176 // That block has not been created yet so use this block as a dummy for now
2177 // and fix this after creating the operation.
2178 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2179 return llvm::Error::success();
2180 };
2181 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2182 // created in case the body contains omp.cancel (which will then expect to be
2183 // able to find this cleanup callback).
2184 ompBuilder.pushFinalizationCB(
2185 {finiCB, cancelDirective, constructIsCancellable(op)});
2186}
2187
2188/// If we cancelled the construct, we should branch to the finalization block of
2189/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2190/// is immediately before the continuation block. Now this finalization has
2191/// been created we can fix the branch.
2192static void
2194 llvm::OpenMPIRBuilder &ompBuilder,
2195 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2196 ompBuilder.popFinalizationCB();
2197 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2198 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2199 cancelBranch->setSuccessor(constructFini);
2200}
2201
2202namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  // Stores references only; the manager emits IR lazily via the methods below.
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  // GEPs populated by the zero-argument createGEPsToPrivateVars() overload.
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  // Null until generateTaskContextStruct() allocates the structure.
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
2258
2259/// IteratorInfo extracts and prepares loop bounds information from an
2260/// mlir::omp::IteratorOp for lowering to LLVM IR.
2261///
2262/// It computes the per-dimension trip counts and the total linearized trip
2263/// count, casted to i64. These are used to build a canonical loop and to
2264/// reconstruct the physical induction variables inside the loop body.
2265class IteratorInfo {
2266private:
2267 llvm::SmallVector<llvm::Value *> lowerBounds;
2268 llvm::SmallVector<llvm::Value *> upperBounds;
2269 llvm::SmallVector<llvm::Value *> steps;
2270 llvm::SmallVector<llvm::Value *> trips;
2271 unsigned dims;
2272 llvm::Value *totalTrips;
2273
2274 llvm::Value *lookUpAsI64(mlir::Value val, const LLVM::ModuleTranslation &mt,
2275 llvm::IRBuilderBase &builder) {
2276 llvm::Value *v = mt.lookupValue(val);
2277 if (!v)
2278 return nullptr;
2279 if (v->getType()->isIntegerTy(64))
2280 return v;
2281 if (v->getType()->isIntegerTy())
2282 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2283 return nullptr;
2284 }
2285
2286public:
2287 IteratorInfo(mlir::omp::IteratorOp itersOp,
2288 mlir::LLVM::ModuleTranslation &moduleTranslation,
2289 llvm::IRBuilderBase &builder) {
2290 dims = itersOp.getLoopLowerBounds().size();
2291 lowerBounds.resize(dims);
2292 upperBounds.resize(dims);
2293 steps.resize(dims);
2294 trips.resize(dims);
2295
2296 for (unsigned d = 0; d < dims; ++d) {
2297 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2298 moduleTranslation, builder);
2299 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2300 moduleTranslation, builder);
2301 llvm::Value *st =
2302 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2303 assert(lb && ub && st &&
2304 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2305 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2306 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2307 "Expect non-zero step in IteratorOp");
2308
2309 lowerBounds[d] = lb;
2310 upperBounds[d] = ub;
2311 steps[d] = st;
2312
2313 // trips = ((ub - lb) / step) + 1 (inclusive ub, assume positive step)
2314 llvm::Value *diff = builder.CreateSub(ub, lb);
2315 llvm::Value *div = builder.CreateSDiv(diff, st);
2316 trips[d] = builder.CreateAdd(
2317 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2318 }
2319
2320 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2321 for (unsigned d = 0; d < dims; ++d)
2322 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2323 }
2324
2325 unsigned getDims() const { return dims; }
2326 llvm::ArrayRef<llvm::Value *> getLowerBounds() const { return lowerBounds; }
2327 llvm::ArrayRef<llvm::Value *> getUpperBounds() const { return upperBounds; }
2328 llvm::ArrayRef<llvm::Value *> getSteps() const { return steps; }
2329 llvm::ArrayRef<llvm::Value *> getTrips() const { return trips; }
2330 llvm::Value *getTotalTrips() const { return totalTrips; }
2331};
2332
2333} // namespace
2334
2335void TaskContextStructManager::generateTaskContextStruct() {
2336 if (privateDecls.empty())
2337 return;
2338 privateVarTypes.reserve(privateDecls.size());
2339
2340 for (omp::PrivateClauseOp &privOp : privateDecls) {
2341 // Skip private variables which can safely be allocated and initialised
2342 // inside of the task
2343 if (!privOp.readsFromMold())
2344 continue;
2345 Type mlirType = privOp.getType();
2346 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2347 }
2348
2349 if (privateVarTypes.empty())
2350 return;
2351
2352 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2353 privateVarTypes);
2354
2355 llvm::DataLayout dataLayout =
2356 builder.GetInsertBlock()->getModule()->getDataLayout();
2357 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2358 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2359
2360 // Heap allocate the structure
2361 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2362 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2363 "omp.task.context_ptr");
2364}
2365
2366SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2367 llvm::Value *altStructPtr) const {
2368 SmallVector<llvm::Value *> ret;
2369
2370 // Create GEPs for each struct member
2371 ret.reserve(privateDecls.size());
2372 llvm::Value *zero = builder.getInt32(0);
2373 unsigned i = 0;
2374 for (auto privDecl : privateDecls) {
2375 if (!privDecl.readsFromMold()) {
2376 // Handle this inside of the task so we don't pass unnessecary vars in
2377 ret.push_back(nullptr);
2378 continue;
2379 }
2380 llvm::Value *iVal = builder.getInt32(i);
2381 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2382 ret.push_back(gep);
2383 i += 1;
2384 }
2385 return ret;
2386}
2387
2388void TaskContextStructManager::createGEPsToPrivateVars() {
2389 if (!structPtr)
2390 assert(privateVarTypes.empty());
2391 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2392 // with null values for skipped private decls
2393
2394 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2395}
2396
2397void TaskContextStructManager::freeStructPtr() {
2398 if (!structPtr)
2399 return;
2400
2401 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2402 // Ensure we don't put the call to free() after the terminator
2403 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2404 builder.CreateFree(structPtr);
2405}
2406
2407static void storeAffinityEntry(llvm::IRBuilderBase &builder,
2408 llvm::OpenMPIRBuilder &ompBuilder,
2409 llvm::Value *affinityList, llvm::Value *index,
2410 llvm::Value *addr, llvm::Value *len) {
2411 llvm::StructType *kmpTaskAffinityInfoTy =
2412 ompBuilder.getKmpTaskAffinityInfoTy();
2413 llvm::Value *entry = builder.CreateInBoundsGEP(
2414 kmpTaskAffinityInfoTy, affinityList, index, "omp.affinity.entry");
2415
2416 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2417 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2418 /*isSigned=*/false);
2419 llvm::Value *flags = builder.getInt32(0);
2420
2421 builder.CreateStore(addr,
2422 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2423 builder.CreateStore(len,
2424 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2425 builder.CreateStore(flags,
2426 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2427}
2428
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::Value *affinityList) {
  // NOTE(review): the line introducing this definition (return type, name and
  // the `affinityVars` parameter) appears to be missing from this chunk of the
  // file — confirm against the full source.
  // Fills `affinityList` with one kmp_task_affinity_info entry per element of
  // `affinityVars`; each element must be defined by an omp.affinity_entry op
  // carrying the range's address and length.
  for (auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
    auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
    assert(entryOp && "affinity item must be omp.affinity_entry");

    llvm::Value *addr = moduleTranslation.lookupValue(entryOp.getAddr());
    llvm::Value *len = moduleTranslation.lookupValue(entryOp.getLen());
    assert(addr && len && "expect affinity addr and len to be non-null");
    storeAffinityEntry(builder, *moduleTranslation.getOpenMPBuilder(),
                       affinityList, builder.getInt64(i), addr, len);
  }
}
2444
2445static mlir::LogicalResult
2446convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo,
2447 mlir::Block &iteratorRegionBlock,
2448 llvm::IRBuilderBase &builder,
2449 LLVM::ModuleTranslation &moduleTranslation) {
2450 llvm::Value *tmp = linearIV;
2451 for (int d = (int)iterInfo.getDims() - 1; d >= 0; --d) {
2452 llvm::Value *trip = iterInfo.getTrips()[d];
2453 // idx_d = tmp % trip_d
2454 llvm::Value *idx = builder.CreateURem(tmp, trip);
2455 // tmp = tmp / trip_d
2456 tmp = builder.CreateUDiv(tmp, trip);
2457
2458 // physIV_d = lb_d + idx_d * step_d
2459 llvm::Value *physIV = builder.CreateAdd(
2460 iterInfo.getLowerBounds()[d],
2461 builder.CreateMul(idx, iterInfo.getSteps()[d]), "omp.it.phys_iv");
2462
2463 moduleTranslation.mapValue(iteratorRegionBlock.getArgument(d), physIV);
2464 }
2465
2466 // Translate the iterator region into the loop body.
2467 moduleTranslation.mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2468 if (mlir::failed(moduleTranslation.convertBlock(iteratorRegionBlock,
2469 /*ignoreArguments=*/true,
2470 builder))) {
2471 return mlir::failure();
2472 }
2473 return mlir::success();
2474}
2475
2477 llvm::function_ref<void(llvm::Value *linearIV, mlir::omp::YieldOp yield)>;
2478
2479static mlir::LogicalResult
2480fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder,
2481 mlir::LLVM::ModuleTranslation &moduleTranslation,
2482 IteratorInfo &iterInfo, llvm::StringRef loopName,
2483 IteratorStoreEntryTy genStoreEntry) {
2484 mlir::Region &itersRegion = itersOp.getRegion();
2485 mlir::Block &iteratorRegionBlock = itersRegion.front();
2486
2487 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2488
2489 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2490 llvm::Value *linearIV) -> llvm::Error {
2491 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2492 builder.restoreIP(bodyIP);
2493
2494 if (failed(convertIteratorRegion(linearIV, iterInfo, iteratorRegionBlock,
2495 builder, moduleTranslation))) {
2496 return llvm::make_error<llvm::StringError>(
2497 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2498 }
2499
2500 auto yield =
2501 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.getTerminator());
2502 assert(yield && yield.getResults().size() == 1 &&
2503 "expect omp.yield in iterator region to have one result");
2504
2505 genStoreEntry(linearIV, yield);
2506
2507 // Iterator-region block/value mappings are temporary for this conversion,
2508 // clear them to avoid stale entries in ModuleTranslation.
2509 moduleTranslation.forgetMapping(itersRegion);
2510
2511 return llvm::Error::success();
2512 };
2513
2514 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2515 moduleTranslation.getOpenMPBuilder()->createIteratorLoop(
2516 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2517 if (failed(handleError(afterIP, *itersOp)))
2518 return failure();
2519
2520 builder.restoreIP(*afterIP);
2521
2522 return mlir::success();
2523}
2524
/// Lower a task's affinity information (direct affinity operands and
/// iterator-expanded affinities) into the single AffinityData descriptor
/// `ad` consumed by OpenMPIRBuilder::createTask. On success `ad.Count`
/// holds the total number of entries (i32) and `ad.Info` points at one
/// contiguous array of kmp_task_affinity_info records; both are null when
/// the task has no affinity information at all.
2525 static mlir::LogicalResult
2526 buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder,
2527 mlir::LLVM::ModuleTranslation &moduleTranslation,
2528 llvm::OpenMPIRBuilder::AffinityData &ad) {
2529
2530 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2531 ad.Count = nullptr;
2532 ad.Info = nullptr;
2533 return mlir::success();
2534 }
2535
2537 llvm::StructType *kmpTaskAffinityInfoTy =
2538 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2539
// Allocate an array of affinity records. When the element count is a
// constant or a function argument the alloca is hoisted to the function's
// alloca insertion point; otherwise it stays at the current position (the
// count is only computed here).
2540 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2541 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2542 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2543 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
2544 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2545 "omp.affinity_list");
2546 };
2547
// Wrap a (count, list) pair into an AffinityData with the i32 count and
// generic-pointer Info expected downstream.
2548 auto createAffinity =
2549 [&](llvm::Value *count,
2550 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2551 llvm::OpenMPIRBuilder::AffinityData ad{};
2552 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2553 ad.Info =
2554 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2555 return ad;
2556 };
2557
// Direct affinity operands: one entry each, filled by fillAffinityLocators.
2558 if (!taskOp.getAffinityVars().empty()) {
2559 llvm::Value *count = llvm::ConstantInt::get(
2560 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2561 llvm::Value *list = allocateAffinityList(count);
2562 fillAffinityLocators(taskOp.getAffinityVars(), builder, moduleTranslation,
2563 list);
2564 ads.emplace_back(createAffinity(count, list));
2565 }
2566
// Iterator-based affinities: each omp.iterator expands to one entry per
// iteration; the per-iteration entry is stored by the callback below.
2567 if (!taskOp.getIterated().empty()) {
2568 for (auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2569 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2570 assert(itersOp && "iterated value must be defined by omp.iterator");
2571 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2572 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2573 if (failed(fillIteratorLoop(
2574 itersOp, builder, moduleTranslation, iterInfo, "iterator",
2575 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2576 auto entryOp = yield.getResults()[0]
2577 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2578 assert(entryOp && "expect yield produce an affinity entry");
2579 llvm::Value *addr =
2580 moduleTranslation.lookupValue(entryOp.getAddr());
2581 llvm::Value *len =
2582 moduleTranslation.lookupValue(entryOp.getLen());
2583 storeAffinityEntry(builder,
2584 *moduleTranslation.getOpenMPBuilder(),
2585 affList, linearIV, addr, len);
2586 })))
2587 return llvm::failure();
2588 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2589 }
2590 }
2591
// Sum the entry counts of all partial lists (as unsigned i32).
2592 llvm::Value *totalAffinityCount = builder.getInt32(0);
2593 for (const auto &affinity : ads)
2594 totalAffinityCount = builder.CreateAdd(
2595 totalAffinityCount,
2596 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2597 /*isSigned=*/false));
2598
// With more than one partial list, the runtime needs a single contiguous
// array: allocate one large list and memcpy each partial list into it back
// to back.
2599 llvm::Value *affinityInfo = ads.front().Info;
2600 if (ads.size() > 1) {
2601 llvm::StructType *kmpTaskAffinityInfoTy =
2602 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2603 llvm::Value *affinityInfoElemSize = builder.getInt64(
2604 moduleTranslation.getLLVMModule()->getDataLayout().getTypeAllocSize(
2605 kmpTaskAffinityInfoTy));
2606
2607 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2608 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2609 for (const auto &affinity : ads) {
2610 llvm::Value *affinityCount = builder.CreateIntCast(
2611 affinity.Count, builder.getInt32Ty(), /*isSigned=*/false);
2612 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2613 affinityCount, builder.getInt64Ty(), /*isSigned=*/false);
2614 llvm::Value *affinityInfoSize =
2615 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2616
2617 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2618 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2619 /*isSigned=*/false);
2620 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2621 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2622
2623 builder.CreateMemCpy(
2624 packedAffinityInfoIndex, llvm::Align(1),
2625 builder.CreatePointerBitCastOrAddrSpaceCast(
2626 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2627 ->getPointerAddressSpace())),
2628 llvm::Align(1), affinityInfoSize);
2629
2630 packedAffinityInfoOffset =
2631 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2632 }
2633
2634 affinityInfo = packedAffinityInfo;
2635 }
2636
2637 ad.Count = totalAffinityCount;
2638 ad.Info = affinityInfo;
2639
2640 return mlir::success();
2641 }
2642
2643 /// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2644 static LogicalResult
2645 convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2646 LLVM::ModuleTranslation &moduleTranslation) {
2647 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2648 if (failed(checkImplementationStatus(*taskOp)))
2649 return failure();
2650
2651 PrivateVarsInfo privateVarsInfo(taskOp);
2652 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2653 privateVarsInfo.privatizers};
2654
2655 // Allocate and copy private variables before creating the task. This avoids
2656 // accessing invalid memory if (after this scope ends) the private variables
2657 // are initialized from host variables or if the variables are copied into
2658 // from host variables (firstprivate). The insertion point is just before
2659 // where the code for creating and scheduling the task will go. That puts this
2660 // code outside of the outlined task region, which is what we want because
2661 // this way the initialization and copy regions are executed immediately while
2662 // the host variable data are still live.
2663
2664 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2665 findAllocaInsertPoint(builder, moduleTranslation);
2666
2667 // Not using splitBB() because that requires the current block to have a
2668 // terminator.
2669 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2670 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2671 builder.getContext(), "omp.task.start",
2672 /*Parent=*/builder.GetInsertBlock()->getParent());
2673 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2674 builder.SetInsertPoint(branchToTaskStartBlock);
2675
2676 // Now do this again to make the initialization and copy blocks
2677 llvm::BasicBlock *copyBlock =
2678 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2679 llvm::BasicBlock *initBlock =
2680 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2681
2682 // Now the control flow graph should look like
2683 // starter_block:
2684 // <---- where we started when convertOmpTaskOp was called
2685 // br %omp.private.init
2686 // omp.private.init:
2687 // br %omp.private.copy
2688 // omp.private.copy:
2689 // br %omp.task.start
2690 // omp.task.start:
2691 // <---- where we want the insertion point to be when we call createTask()
2692
2693 // Save the alloca insertion point on ModuleTranslation stack for use in
2694 // nested regions.
2696 moduleTranslation, allocaIP);
2697
2698 // Allocate and initialize private variables
2699 builder.SetInsertPoint(initBlock->getTerminator());
2700
2701 // Create task variable structure
2702 taskStructMgr.generateTaskContextStruct();
2703 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2704 // of the body otherwise it will be the GEP not the struct which is forwarded
2705 // to the outlined function. GEPs forwarded in this way are passed in a
2706 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2707 // which may not be executed until after the current stack frame goes out of
2708 // scope.
2709 taskStructMgr.createGEPsToPrivateVars();
2710
// Initialize (in the init block) every private variable whose init region
// reads from the original ("mold") variable; the rest are handled inside the
// task body callback below.
2711 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2712 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2713 privateVarsInfo.blockArgs,
2714 taskStructMgr.getLLVMPrivateVarGEPs())) {
2715 // To be handled inside the task.
2716 if (!privDecl.readsFromMold())
2717 continue;
2718 assert(llvmPrivateVarAlloc &&
2719 "reads from mold so shouldn't have been skipped");
2720
2721 llvm::Expected<llvm::Value *> privateVarOrErr =
2722 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2723 blockArg, llvmPrivateVarAlloc, initBlock);
2724 if (!privateVarOrErr)
2725 return handleError(privateVarOrErr, *taskOp.getOperation());
2726
2728
2729 // TODO: this is a bit of a hack for Fortran character boxes.
2730 // Character boxes are passed by value into the init region and then the
2731 // initialized character box is yielded by value. Here we need to store the
2732 // yielded value into the private allocation, and load the private
2733 // allocation to match the type expected by region block arguments.
2734 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2735 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2736 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2737 // Load it so we have the value pointed to by the GEP
2738 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2739 llvmPrivateVarAlloc);
2740 }
2741 assert(llvmPrivateVarAlloc->getType() ==
2742 moduleTranslation.convertType(blockArg.getType()));
2743
2744 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2745 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2746 // stack allocated structure.
2747 }
2748
2749 // firstprivate copy region
2750 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2751 if (failed(copyFirstPrivateVars(
2752 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2753 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2754 taskOp.getPrivateNeedsBarrier())))
2755 return llvm::failure();
2756
// Lower the affinity clause (including iterator-expanded entries) into the
// descriptor passed to createTask(); ad.Count/ad.Info stay null when the
// task has no affinity information.
2757 llvm::OpenMPIRBuilder::AffinityData ad;
2758 if (failed(buildAffinityData(taskOp, builder, moduleTranslation, ad)))
2759 return llvm::failure();
2760
2761 // Set up for call to createTask()
2762 builder.SetInsertPoint(taskStartBlock);
2763
// Callback that emits the task body; it runs with the insertion points of
// the outlined task region supplied by OpenMPIRBuilder.
2764 auto bodyCB = [&](InsertPointTy allocaIP,
2765 InsertPointTy codegenIP) -> llvm::Error {
2766 // Save the alloca insertion point on ModuleTranslation stack for use in
2767 // nested regions.
2769 moduleTranslation, allocaIP);
2770
2771 // translate the body of the task:
2772 builder.restoreIP(codegenIP);
2773
// Allocate and initialize, inside the task, the private variables that do
// not read from the mold (the mold-reading ones were initialized earlier,
// outside the task).
2774 llvm::BasicBlock *privInitBlock = nullptr;
2775 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2776 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2777 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2778 privateVarsInfo.mlirVars))) {
2779 auto [blockArg, privDecl, mlirPrivVar] = zip;
2780 // This is handled before the task executes
2781 if (privDecl.readsFromMold())
2782 continue;
2783
2784 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2785 llvm::Type *llvmAllocType =
2786 moduleTranslation.convertType(privDecl.getType());
2787 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2788 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2789 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2790
2791 llvm::Expected<llvm::Value *> privateVarOrError =
2792 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2793 blockArg, llvmPrivateVar, privInitBlock);
2794 if (!privateVarOrError)
2795 return privateVarOrError.takeError();
2796 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2797 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2798 }
2799
2800 taskStructMgr.createGEPsToPrivateVars();
2801 for (auto [i, llvmPrivVar] :
2802 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2803 if (!llvmPrivVar) {
2804 assert(privateVarsInfo.llvmVars[i] &&
2805 "This is added in the loop above");
2806 continue;
2807 }
2808 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2809 }
2810
2811 // Find and map the addresses of each variable within the task context
2812 // structure
2813 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2814 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2815 privateVarsInfo.privatizers)) {
2816 // This was handled above.
2817 if (!privateDecl.readsFromMold())
2818 continue;
2819 // Fix broken pass-by-value case for Fortran character boxes
2820 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2821 llvmPrivateVar = builder.CreateLoad(
2822 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2823 }
2824 assert(llvmPrivateVar->getType() ==
2825 moduleTranslation.convertType(blockArg.getType()));
2826 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2827 }
2828
2829 auto continuationBlockOrError = convertOmpOpRegions(
2830 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2831 if (failed(handleError(continuationBlockOrError, *taskOp)))
2832 return llvm::make_error<PreviouslyReportedError>();
2833
2834 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2835
2836 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2837 privateVarsInfo.llvmVars,
2838 privateVarsInfo.privatizers)))
2839 return llvm::make_error<PreviouslyReportedError>();
2840
2841 // Free heap allocated task context structure at the end of the task.
2842 taskStructMgr.freeStructPtr();
2843
2844 return llvm::Error::success();
2845 };
2846
2847 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2848 SmallVector<llvm::UncondBrInst *> cancelTerminators;
2849 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2850 // which is canceled. This is handled here because it is the task's cleanup
2851 // block which should be branched to.
2852 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2853 llvm::omp::Directive::OMPD_taskgroup);
2855
// Translate depend clause operands into the dependency descriptors consumed
// by createTask().
2856 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2857 moduleTranslation, dds)\u003b
2858
2859 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2860 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2861 moduleTranslation.getOpenMPBuilder()->createTask(
2862 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2863 moduleTranslation.lookupValue(taskOp.getFinal()),
2864 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds, ad,
2865 taskOp.getMergeable(),
2866 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2867 moduleTranslation.lookupValue(taskOp.getPriority()));
2868
2869 if (failed(handleError(afterIP, *taskOp)))
2870 return failure();
2871
2872 // Set the correct branch target for task cancellation
2873 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2874
2875 builder.restoreIP(*afterIP);
2876 return success();
2877}
2878
2879// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
2880static LogicalResult
2881convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
2882 LLVM::ModuleTranslation &moduleTranslation) {
2883 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2884 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2885 if (failed(checkImplementationStatus(opInst)))
2886 return failure();
2887
2888 // It stores the pointer of allocated firstprivate copies,
2889 // which can be used later for freeing the allocated space.
2890 SmallVector<llvm::Value *> llvmFirstPrivateVars;
2891 PrivateVarsInfo privateVarsInfo(taskloopOp);
2892 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2893 privateVarsInfo.privatizers};
2894
2895 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2896 findAllocaInsertPoint(builder, moduleTranslation);
2897
2898 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2899 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2900 builder.getContext(), "omp.taskloop.start",
2901 /*Parent=*/builder.GetInsertBlock()->getParent());
2902 llvm::Instruction *branchToTaskloopStartBlock =
2903 builder.CreateBr(taskloopStartBlock);
2904 builder.SetInsertPoint(branchToTaskloopStartBlock);
2905
2906 llvm::BasicBlock *copyBlock =
2907 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2908 llvm::BasicBlock *initBlock =
2909 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2910
2912 moduleTranslation, allocaIP);
2913
2914 // Allocate and initialize private variables
2915 builder.SetInsertPoint(initBlock->getTerminator());
2916
2917 // TODO: don't allocate if the loop has zero iterations.
2918 taskStructMgr.generateTaskContextStruct();
2919 taskStructMgr.createGEPsToPrivateVars();
2920
2921 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
2922
2923 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2924 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2925 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2926 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2927 // To be handled inside the taskloop.
2928 if (!privDecl.readsFromMold())
2929 continue;
2930 assert(llvmPrivateVarAlloc &&
2931 "reads from mold so shouldn't have been skipped");
2932
2933 llvm::Expected<llvm::Value *> privateVarOrErr =
2934 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2935 blockArg, llvmPrivateVarAlloc, initBlock);
2936 if (!privateVarOrErr)
2937 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2938
2939 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2940
2941 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2942 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2943
2944 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2945 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2946 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2947 // Load it so we have the value pointed to by the GEP
2948 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2949 llvmPrivateVarAlloc);
2950 }
2951 assert(llvmPrivateVarAlloc->getType() ==
2952 moduleTranslation.convertType(blockArg.getType()));
2953 }
2954
2955 // firstprivate copy region
2956 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2957 if (failed(copyFirstPrivateVars(
2958 taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2959 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2960 taskloopOp.getPrivateNeedsBarrier())))
2961 return llvm::failure();
2962
2963 // Set up insertion point for call to createTaskloop()
2964 builder.SetInsertPoint(taskloopStartBlock);
2965
2966 auto bodyCB = [&](InsertPointTy allocaIP,
2967 InsertPointTy codegenIP) -> llvm::Error {
2968 // Save the alloca insertion point on ModuleTranslation stack for use in
2969 // nested regions.
2971 moduleTranslation, allocaIP);
2972
2973 // translate the body of the taskloop:
2974 builder.restoreIP(codegenIP);
2975
2976 llvm::BasicBlock *privInitBlock = nullptr;
2977 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2978 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2979 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2980 privateVarsInfo.mlirVars))) {
2981 auto [blockArg, privDecl, mlirPrivVar] = zip;
2982 // This is handled before the task executes
2983 if (privDecl.readsFromMold())
2984 continue;
2985
2986 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2987 llvm::Type *llvmAllocType =
2988 moduleTranslation.convertType(privDecl.getType());
2989 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2990 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2991 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2992
2993 llvm::Expected<llvm::Value *> privateVarOrError =
2994 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2995 blockArg, llvmPrivateVar, privInitBlock);
2996 if (!privateVarOrError)
2997 return privateVarOrError.takeError();
2998 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2999 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3000 }
3001
3002 taskStructMgr.createGEPsToPrivateVars();
3003 for (auto [i, llvmPrivVar] :
3004 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3005 if (!llvmPrivVar) {
3006 assert(privateVarsInfo.llvmVars[i] &&
3007 "This is added in the loop above");
3008 continue;
3009 }
3010 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3011 }
3012
3013 // Find and map the addresses of each variable within the taskloop context
3014 // structure
3015 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3016 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3017 privateVarsInfo.privatizers)) {
3018 // This was handled above.
3019 if (!privateDecl.readsFromMold())
3020 continue;
3021 // Fix broken pass-by-value case for Fortran character boxes
3022 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3023 llvmPrivateVar = builder.CreateLoad(
3024 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3025 }
3026 assert(llvmPrivateVar->getType() ==
3027 moduleTranslation.convertType(blockArg.getType()));
3028 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3029 }
3030
3031 auto continuationBlockOrError =
3032 convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
3033 builder, moduleTranslation);
3034
3035 if (failed(handleError(continuationBlockOrError, opInst)))
3036 return llvm::make_error<PreviouslyReportedError>();
3037
3038 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3039
3040 // This is freeing the private variables as mapped inside of the task: these
3041 // will be per-task private copies possibly after task duplication. This is
3042 // handled transparently by how these are passed to the structure passed
3043 // into the outlined function. When the task is duplicated, that structure
3044 // is duplicated too.
3045 if (failed(cleanupPrivateVars(builder, moduleTranslation,
3046 taskloopOp.getLoc(), privateVarsInfo.llvmVars,
3047 privateVarsInfo.privatizers)))
3048 return llvm::make_error<PreviouslyReportedError>();
3049 // Similarly, the task context structure freed inside the task is the
3050 // per-task copy after task duplication.
3051 taskStructMgr.freeStructPtr();
3052
3053 return llvm::Error::success();
3054 };
3055
3056 // Taskloop divides into an appropriate number of tasks by repeatedly
3057 // duplicating the original task. Each time this is done, the task context
3058 // structure must be duplicated too.
3059 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3060 llvm::Value *destPtr, llvm::Value *srcPtr)
3062 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3063 builder.restoreIP(codegenIP);
3064
3065 llvm::Type *ptrTy =
3066 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3067 llvm::Value *src =
3068 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
3069
3070 TaskContextStructManager &srcStructMgr = taskStructMgr;
3071 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3072 privateVarsInfo.privatizers);
3073 destStructMgr.generateTaskContextStruct();
3074 llvm::Value *dest = destStructMgr.getStructPtr();
3075 dest->setName("omp.taskloop.context.dest");
3076 builder.CreateStore(dest, destPtr);
3077
3079 srcStructMgr.createGEPsToPrivateVars(src);
3081 destStructMgr.createGEPsToPrivateVars(dest);
3082
3083 // Inline init regions.
3084 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3085 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
3086 privateVarsInfo.blockArgs, destGEPs)) {
3087 // To be handled inside task body.
3088 if (!privDecl.readsFromMold())
3089 continue;
3090 assert(llvmPrivateVarAlloc &&
3091 "reads from mold so shouldn't have been skipped");
3092
3093 llvm::Expected<llvm::Value *> privateVarOrErr =
3094 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3095 llvmPrivateVarAlloc, builder.GetInsertBlock());
3096 if (!privateVarOrErr)
3097 return privateVarOrErr.takeError();
3098
3100
3101 // TODO: this is a bit of a hack for Fortran character boxes.
3102 // Character boxes are passed by value into the init region and then the
3103 // initialized character box is yielded by value. Here we need to store
3104 // the yielded value into the private allocation, and load the private
3105 // allocation to match the type expected by region block arguments.
3106 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3107 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3108 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3109 // Load it so we have the value pointed to by the GEP
3110 llvmPrivateVarAlloc = builder.CreateLoad(
3111 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3112 }
3113 assert(llvmPrivateVarAlloc->getType() ==
3114 moduleTranslation.convertType(blockArg.getType()));
3115
3116 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
3117 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
3118 // through a stack allocated structure.
3119 }
3120
3121 if (failed(copyFirstPrivateVars(
3122 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
3123 privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
3124 return llvm::make_error<PreviouslyReportedError>();
3125
3126 return builder.saveIP();
3127 };
3128
3129 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
3130
3131 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
3132 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3133 return loopInfo;
3134 };
3135
3136 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
3137 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
3138 Operation::operand_range steps = loopOp.getLoopSteps();
3139 llvm::Type *boundType =
3140 moduleTranslation.lookupValue(lowerBounds[0])->getType();
3141 llvm::Value *lbVal = nullptr;
3142 llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3143 llvm::Value *stepVal = nullptr;
3144 if (loopOp.getCollapseNumLoops() > 1) {
3145 // In cases where Collapse is used with Taskloop, the upper bound of the
3146 // iteration space needs to be recalculated to cater for the collapsed loop.
3147 // The Collapsed Loop UpperBound is the product of all collapsed
3148 // loop's tripcount.
3149 // The LowerBound for collapsed loops is always 1. When the loops are
3150 // collapsed, it will reset the bounds and introduce processing to ensure
3151 // the indices are presented as expected. As this happens after creating
3152 // Taskloop, these bounds need predicting. Example:
3153 // !$omp taskloop collapse(2)
3154 // do i = 1, 10
3155 // do j = 1, 5
3156 // ..
3157 // end do
3158 // end do
3159 // This loop above has a total of 50 iterations, so the lb will be 1, and
3160 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
3161 // that i and j are properly presented when used in the loop.
3162 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3163 llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
3164 llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
3165 llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);
3166 // In some cases, such as where the ub is less than the lb so the loop
3167 // steps down, the calculation for the loopTripCount is swapped. To ensure
3168 // the correct value is found, calculate both UB - LB and LB - UB then
3169 // select which value to use depending on how the loop has been
3170 // configured.
3171 llvm::Value *loopLbMinusOne = builder.CreateSub(
3172 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3173 llvm::Value *loopUbMinusOne = builder.CreateSub(
3174 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3175 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3176 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3177 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3178 llvm::Value *loopTripCount =
3179 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3180 loopTripCount = builder.CreateBinaryIntrinsic(
3181 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3182 // For loops that have a step value not equal to 1, we need to adjust the
3183 // trip count to ensure the correct number of iterations for the loop is
3184 // captured.
3185 llvm::Value *loopTripCountDivStep =
3186 builder.CreateSDiv(loopTripCount, loopStep);
3187 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3188 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3189 llvm::Value *loopTripCountRem =
3190 builder.CreateSRem(loopTripCount, loopStep);
3191 loopTripCountRem = builder.CreateBinaryIntrinsic(
3192 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3193 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3194 loopTripCountRem,
3195 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3196 0));
3197 loopTripCount =
3198 builder.CreateAdd(loopTripCountDivStep,
3199 builder.CreateZExtOrTrunc(
3200 needsRoundUp, loopTripCountDivStep->getType()));
3201 ubVal = builder.CreateMul(ubVal, loopTripCount);
3202 }
3203 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3204 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3205 } else {
3206 lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
3207 ubVal = moduleTranslation.lookupValue(upperBounds[0]);
3208 stepVal = moduleTranslation.lookupValue(steps[0]);
3209 }
3210 assert(lbVal != nullptr && "Expected value for lbVal");
3211 assert(ubVal != nullptr && "Expected value for ubVal");
3212 assert(stepVal != nullptr && "Expected value for stepVal");
3213
3214 llvm::Value *ifCond = nullptr;
3215 llvm::Value *grainsize = nullptr;
3216 int sched = 0; // default
3217 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
3218 mlir::Value numTasksVal = taskloopOp.getNumTasks();
3219 if (Value ifVar = taskloopOp.getIfExpr())
3220 ifCond = moduleTranslation.lookupValue(ifVar);
3221 if (grainsizeVal) {
3222 grainsize = moduleTranslation.lookupValue(grainsizeVal);
3223 sched = 1; // grainsize
3224 } else if (numTasksVal) {
3225 grainsize = moduleTranslation.lookupValue(numTasksVal);
3226 sched = 2; // num_tasks
3227 }
3228
3229 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
3230 if (taskStructMgr.getStructPtr())
3231 taskDupOrNull = taskDupCB;
3232
3233 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3234 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3235 // The directive to match here is OMPD_taskgroup because it is the
3236 // taskgroup which is canceled. This is handled here because it is the
3237 // task's cleanup block which should be branched to. It doesn't depend upon
3238 // nogroup because even in that case the taskloop might still be inside an
3239 // explicit taskgroup.
3240 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
3241 llvm::omp::Directive::OMPD_taskgroup);
3242
3243 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3244 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3245 moduleTranslation.getOpenMPBuilder()->createTaskloop(
3246 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
3247 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
3248 sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
3249 taskloopOp.getMergeable(),
3250 moduleTranslation.lookupValue(taskloopOp.getPriority()),
3251 loopOp.getCollapseNumLoops(), taskDupOrNull,
3252 taskStructMgr.getStructPtr());
3253
3254 if (failed(handleError(afterIP, opInst)))
3255 return failure();
3256
3257 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3258
3259 builder.restoreIP(*afterIP);
3260 return success();
3261}
3262
3263/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
3264static LogicalResult
3265convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
3266 LLVM::ModuleTranslation &moduleTranslation) {
3267 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3268 if (failed(checkImplementationStatus(*tgOp)))
3269 return failure();
3270
3271 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
3272 builder.restoreIP(codegenIP);
3273 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
3274 builder, moduleTranslation)
3275 .takeError();
3276 };
3277
3278 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
3279 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3280 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3281 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
3282 bodyCB);
3283
3284 if (failed(handleError(afterIP, *tgOp)))
3285 return failure();
3286
3287 builder.restoreIP(*afterIP);
3288 return success();
3289}
3290
3291static LogicalResult
3292convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
3293 LLVM::ModuleTranslation &moduleTranslation) {
3294 if (failed(checkImplementationStatus(*twOp)))
3295 return failure();
3296
3297 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
3298 return success();
3299}
3300
3301/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
/// Lowers the `omp.wsloop` wrapper around an `omp.loop_nest`: privatization,
/// reductions, the linear clause, schedule selection, optional distribute
/// chunking, and cancellation finalization are all handled here before
/// delegating the actual loop transformation to `applyWorkshareLoop`.
/// NOTE(review): a few declaration lines (e.g. `reductionDecls`,
/// `afterAllocas`, `regionBlock`) appear elided in this listing — confirm
/// against the upstream file before editing.
3302 static LogicalResult
3303 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
3304 LLVM::ModuleTranslation &moduleTranslation) {
3305 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3306 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3307 if (failed(checkImplementationStatus(opInst)))
3308 return failure();
3309
3310 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3311 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
3312 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3313
3314 // Static is the default.
3315 auto schedule =
3316 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3317
3318 // Find the loop configuration.
// The chunk expression is coerced to the induction variable's width so the
// runtime call sees matching integer types.
3319 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
3320 llvm::Type *ivType = step->getType();
3321 llvm::Value *chunk = nullptr;
3322 if (wsloopOp.getScheduleChunk()) {
3323 llvm::Value *chunkVar =
3324 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
3325 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3326 }
3327
// When directly nested in omp.distribute, pick up dist_schedule information
// (static flag and optional chunk size) from the parent op.
3328 omp::DistributeOp distributeOp = nullptr;
3329 llvm::Value *distScheduleChunk = nullptr;
3330 bool hasDistSchedule = false;
3331 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
3332 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
3333 hasDistSchedule = distributeOp.getDistScheduleStatic();
3334 if (distributeOp.getDistScheduleChunkSize()) {
3335 llvm::Value *chunkVar = moduleTranslation.lookupValue(
3336 distributeOp.getDistScheduleChunkSize());
3337 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3338 }
3339 }
3340
3341 PrivateVarsInfo privateVarsInfo(wsloopOp);
3344 collectReductionDecls(wsloopOp, reductionDecls);
3345 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3346 findAllocaInsertPoint(builder, moduleTranslation);
3347
3348 SmallVector<llvm::Value *> privateReductionVariables(
3349 wsloopOp.getNumReductionVars());
3350
3352 builder, moduleTranslation, privateVarsInfo, allocaIP);
3353 if (handleError(afterAllocas, opInst).failed())
3354 return failure();
3355
3356 DenseMap<Value, llvm::Value *> reductionVariableMap;
3357
3358 MutableArrayRef<BlockArgument> reductionArgs =
3359 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3360
3361 SmallVector<DeferredStore> deferredStores;
3362
// Allocate, then initialize, private copies and reduction variables;
// first-private copies may require a barrier afterwards.
3363 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3364 moduleTranslation, allocaIP, reductionDecls,
3365 privateReductionVariables, reductionVariableMap,
3366 deferredStores, isByRef)))
3367 return failure();
3368
3369 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3370 opInst)
3371 .failed())
3372 return failure();
3373
3374 if (failed(copyFirstPrivateVars(
3375 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3376 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3377 wsloopOp.getPrivateNeedsBarrier())))
3378 return failure();
3379
3380 assert(afterAllocas.get()->getSinglePredecessor());
3381 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3382 moduleTranslation,
3383 afterAllocas.get()->getSinglePredecessor(),
3384 reductionDecls, privateReductionVariables,
3385 reductionVariableMap, isByRef, deferredStores)))
3386 return failure();
3387
3388 // TODO: Handle doacross loops when the ordered clause has a parameter.
3389 bool isOrdered = wsloopOp.getOrdered().has_value();
3390 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3391 bool isSimd = wsloopOp.getScheduleSimd();
3392 bool loopNeedsBarrier = !wsloopOp.getNowait();
3393
3394 // The only legal way for the direct parent to be omp.distribute is that this
3395 // represents 'distribute parallel do'. Otherwise, this is a regular
3396 // worksharing loop.
3397 llvm::omp::WorksharingLoopType workshareLoopType =
3398 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3399 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3400 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3401
// Record branch targets so `omp.cancel`(for) inside the loop jumps to the
// proper finalization block; resolved by popCancelFinalizationCB below.
3402 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3403 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3404 llvm::omp::Directive::OMPD_for);
3405
3406 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3407
3408 // Initialize linear variables and linear step
3409 LinearClauseProcessor linearClauseProcessor;
3410
3411 if (!wsloopOp.getLinearVars().empty()) {
3412 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3413 for (mlir::Attribute linearVarType : linearVarTypes)
3414 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3415
3416 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3417 linearClauseProcessor.createLinearVar(
3418 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3419 idx);
3420 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3421 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3422 }
3423
3424 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3426 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3427
3428 if (failed(handleError(regionBlock, opInst)))
3429 return failure();
3430
3431 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3432
3433 // Emit Initialization and Update IR for linear variables
3434 if (!wsloopOp.getLinearVars().empty()) {
3435 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3436 loopInfo->getPreheader());
3437 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3438 moduleTranslation.getOpenMPBuilder()->createBarrier(
3439 builder.saveIP(), llvm::omp::OMPD_barrier);
3440 if (failed(handleError(afterBarrierIP, *loopOp)))
3441 return failure();
3442 builder.restoreIP(*afterBarrierIP);
3443 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3444 loopInfo->getIndVar());
3445 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3446 }
3447
3448 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3449
3450 // Check if we can generate no-loop kernel
3451 bool noLoopMode = false;
3452 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3453 if (targetOp) {
3454 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3455 // We need this check because, without it, noLoopMode would be set to true
3456 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3457 // that loop is not the top-level SPMD one.
3458 if (loopOp == targetCapturedOp) {
3459 omp::TargetRegionFlags kernelFlags =
3460 targetOp.getKernelExecFlags(targetCapturedOp);
3461 if (omp::bitEnumContainsAll(kernelFlags,
3462 omp::TargetRegionFlags::spmd |
3463 omp::TargetRegionFlags::no_loop) &&
3464 !omp::bitEnumContainsAny(kernelFlags,
3465 omp::TargetRegionFlags::generic))
3466 noLoopMode = true;
3467 }
3468 }
3469
// Delegate the actual worksharing-loop transformation (scheduling, runtime
// calls) to the OpenMPIRBuilder.
3470 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3471 ompBuilder->applyWorkshareLoop(
3472 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3473 convertToScheduleKind(schedule), chunk, isSimd,
3474 scheduleMod == omp::ScheduleModifier::monotonic,
3475 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3476 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3477
3478 if (failed(handleError(wsloopIP, opInst)))
3479 return failure();
3480
3481 // Emit finalization and in-place rewrites for linear vars.
3482 if (!wsloopOp.getLinearVars().empty()) {
3483 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3484 assert(loopInfo->getLastIter() &&
3485 "`lastiter` in CanonicalLoopInfo is nullptr");
3486 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3487 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3488 loopInfo->getLastIter());
3489 if (failed(handleError(afterBarrierIP, *loopOp)))
3490 return failure();
3491 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3492 linearClauseProcessor.rewriteInPlace(
3493 builder, sourceBlock->getSingleSuccessor(), *regionBlock,
3494 "omp.loop_nest.region", index);
3495
3496 builder.restoreIP(oldIP);
3497 }
3498
3499 // Set the correct branch target for task cancellation
3500 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3501
3502 // Process the reductions if required.
3503 if (failed(createReductionsAndCleanup(
3504 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3505 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3506 /*isTeamsReduction=*/false)))
3507 return failure();
3508
// Finally, run the privatizers' dealloc regions for the private copies.
3509 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
3510 privateVarsInfo.llvmVars,
3511 privateVarsInfo.privatizers);
3512}
3513
3514/// Converts the OpenMP parallel operation to LLVM IR.
/// Privatization and reductions are handled inside `bodyGenCB` (run by the
/// OpenMPIRBuilder when it outlines the parallel region); `privCB` is a
/// no-op because of that, and `finiCB` inlines reduction cleanup, frees
/// private copies and, for cancellable constructs, emits the exit barrier.
/// NOTE(review): several declaration lines (e.g. `reductionDecls`,
/// `afterAllocas`, `regionBlock`, the reduction-info vectors) appear elided
/// in this listing — confirm against the upstream file before editing.
3515 static LogicalResult
3516 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3517 LLVM::ModuleTranslation &moduleTranslation) {
3518 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3519 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3520 assert(isByRef.size() == opInst.getNumReductionVars());
3521 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3522 bool isCancellable = constructIsCancellable(opInst);
3523
3524 if (failed(checkImplementationStatus(*opInst)))
3525 return failure();
3526
3527 PrivateVarsInfo privateVarsInfo(opInst);
3528
3529 // Collect reduction declarations
3531 collectReductionDecls(opInst, reductionDecls);
3532 SmallVector<llvm::Value *> privateReductionVariables(
3533 opInst.getNumReductionVars());
3534 SmallVector<DeferredStore> deferredStores;
3535
// Body generator invoked by createParallel inside the outlined function.
3536 auto bodyGenCB = [&](InsertPointTy allocaIP,
3537 InsertPointTy codeGenIP) -> llvm::Error {
3539 builder, moduleTranslation, privateVarsInfo, allocaIP);
3540 if (handleError(afterAllocas, *opInst).failed())
3541 return llvm::make_error<PreviouslyReportedError>();
3542
3543 // Allocate reduction vars
3544 DenseMap<Value, llvm::Value *> reductionVariableMap;
3545
3546 MutableArrayRef<BlockArgument> reductionArgs =
3547 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3548
// Re-anchor the alloca insertion point at the block terminator so reduction
// allocas land after the privatization allocas emitted above.
3549 allocaIP =
3550 InsertPointTy(allocaIP.getBlock(),
3551 allocaIP.getBlock()->getTerminator()->getIterator());
3552
3553 if (failed(allocReductionVars(
3554 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3555 reductionDecls, privateReductionVariables, reductionVariableMap,
3556 deferredStores, isByRef)))
3557 return llvm::make_error<PreviouslyReportedError>();
3558
3559 assert(afterAllocas.get()->getSinglePredecessor());
3560 builder.restoreIP(codeGenIP);
3561
3562 if (handleError(
3563 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3564 *opInst)
3565 .failed())
3566 return llvm::make_error<PreviouslyReportedError>();
3567
3568 if (failed(copyFirstPrivateVars(
3569 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3570 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3571 opInst.getPrivateNeedsBarrier())))
3572 return llvm::make_error<PreviouslyReportedError>();
3573
3574 if (failed(
3575 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3576 afterAllocas.get()->getSinglePredecessor(),
3577 reductionDecls, privateReductionVariables,
3578 reductionVariableMap, isByRef, deferredStores)))
3579 return llvm::make_error<PreviouslyReportedError>();
3580
3581 // Save the alloca insertion point on ModuleTranslation stack for use in
3582 // nested regions.
3584 moduleTranslation, allocaIP);
3585
3586 // ParallelOp has only one region associated with it.
3588 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3589 if (!regionBlock)
3590 return regionBlock.takeError();
3591
3592 // Process the reductions if required.
3593 if (opInst.getNumReductionVars() > 0) {
3594 // Collect reduction info
3596 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3598 owningReductionGenRefDataPtrGens;
3600 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3601 owningReductionGens, owningAtomicReductionGens,
3602 owningReductionGenRefDataPtrGens,
3603 privateReductionVariables, reductionInfos, isByRef);
3604
3605 // Move to region cont block
3606 builder.SetInsertPoint((*regionBlock)->getTerminator());
3607
3608 // Generate reductions from info
// A temporary unreachable terminator marks the split point for
// createReductions; it is erased once the continuation block is known.
3609 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3610 builder.SetInsertPoint(tempTerminator);
3611
3612 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3613 ompBuilder->createReductions(
3614 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3615 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3616 if (!contInsertPoint)
3617 return contInsertPoint.takeError();
3618
3619 if (!contInsertPoint->getBlock())
3620 return llvm::make_error<PreviouslyReportedError>();
3621
3622 tempTerminator->eraseFromParent();
3623 builder.restoreIP(*contInsertPoint);
3624 }
3625
3626 return llvm::Error::success();
3627 };
3628
3629 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3630 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3631 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3632 // bodyGenCB.
3633 replVal = &val;
3634 return codeGenIP;
3635 };
3636
3637 // TODO: Perform finalization actions for variables. This has to be
3638 // called for variables which have destructors/finalizers.
3639 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3640 InsertPointTy oldIP = builder.saveIP();
3641 builder.restoreIP(codeGenIP);
3642
3643 // if the reduction has a cleanup region, inline it here to finalize the
3644 // reduction variables
3645 SmallVector<Region *> reductionCleanupRegions;
3646 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3647 [](omp::DeclareReductionOp reductionDecl) {
3648 return &reductionDecl.getCleanupRegion();
3649 });
3650 if (failed(inlineOmpRegionCleanup(
3651 reductionCleanupRegions, privateReductionVariables,
3652 moduleTranslation, builder, "omp.reduction.cleanup")))
3653 return llvm::createStringError(
3654 "failed to inline `cleanup` region of `omp.declare_reduction`");
3655
3656 if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
3657 privateVarsInfo.llvmVars,
3658 privateVarsInfo.privatizers)))
3659 return llvm::make_error<PreviouslyReportedError>();
3660
3661 // If we could be performing cancellation, add the cancellation barrier on
3662 // the way out of the outlined region.
3663 if (isCancellable) {
3664 auto IPOrErr = ompBuilder->createBarrier(
3665 llvm::OpenMPIRBuilder::LocationDescription(builder),
3666 llvm::omp::Directive::OMPD_unknown,
3667 /* ForceSimpleCall */ false,
3668 /* CheckCancelFlag */ false);
3669 if (!IPOrErr)
3670 return IPOrErr.takeError();
3671 }
3672
3673 builder.restoreIP(oldIP);
3674 return llvm::Error::success();
3675 };
3676
// Translate the if/num_threads/proc_bind clauses into the values the
// OpenMPIRBuilder expects.
3677 llvm::Value *ifCond = nullptr;
3678 if (auto ifVar = opInst.getIfExpr())
3679 ifCond = moduleTranslation.lookupValue(ifVar);
3680 llvm::Value *numThreads = nullptr;
3681 if (!opInst.getNumThreadsVars().empty())
3682 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
3683 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3684 if (auto bind = opInst.getProcBindKind())
3685 pbKind = getProcBindKind(*bind);
3686
3687 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3688 findAllocaInsertPoint(builder, moduleTranslation);
3689 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3690
3691 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3692 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3693 ifCond, numThreads, pbKind, isCancellable);
3694
3695 if (failed(handleError(afterIP, *opInst)))
3696 return failure();
3697
3698 builder.restoreIP(*afterIP);
3699 return success();
3700}
3701
3702/// Convert Order attribute to llvm::omp::OrderKind.
3703static llvm::omp::OrderKind
3704convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3705 if (!o)
3706 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3707 switch (*o) {
3708 case omp::ClauseOrderKind::Concurrent:
3709 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3710 }
3711 llvm_unreachable("Unknown ClauseOrderKind kind");
3712}
3713
3714/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
/// Handles privatization, the linear and aligned clauses, simdlen/safelen,
/// and per-lane reductions. Unlike wsloop, the simd reduction is combined
/// directly in this thread without calling into the OpenMP runtime.
/// NOTE(review): a few declaration lines (e.g. `reductionDecls`,
/// `afterAllocas`, `regionBlock`) appear elided in this listing — confirm
/// against the upstream file before editing.
3715 static LogicalResult
3716 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
3717 LLVM::ModuleTranslation &moduleTranslation) {
3718 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3719 auto simdOp = cast<omp::SimdOp>(opInst);
3720
3721 if (failed(checkImplementationStatus(opInst)))
3722 return failure();
3723
3724 PrivateVarsInfo privateVarsInfo(simdOp);
3725
3726 MutableArrayRef<BlockArgument> reductionArgs =
3727 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3728 DenseMap<Value, llvm::Value *> reductionVariableMap;
3729 SmallVector<llvm::Value *> privateReductionVariables(
3730 simdOp.getNumReductionVars());
3731 SmallVector<DeferredStore> deferredStores;
3733 collectReductionDecls(simdOp, reductionDecls);
3734 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
3735 assert(isByRef.size() == simdOp.getNumReductionVars());
3736
3737 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3738 findAllocaInsertPoint(builder, moduleTranslation);
3739
3741 builder, moduleTranslation, privateVarsInfo, allocaIP);
3742 if (handleError(afterAllocas, opInst).failed())
3743 return failure();
3744
3745 // Initialize linear variables and linear step
3746 LinearClauseProcessor linearClauseProcessor;
3747
3748 if (!simdOp.getLinearVars().empty()) {
3749 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3750 for (mlir::Attribute linearVarType : linearVarTypes)
3751 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3752 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3753 bool isImplicit = false;
3754 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3755 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
3756 // If the linear variable is implicit, reuse the already
3757 // existing llvm::Value
3758 if (linearVar == mlirPrivVar) {
3759 isImplicit = true;
3760 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3761 llvmPrivateVar, idx);
3762 break;
3763 }
3764 }
3765
3766 if (!isImplicit)
3767 linearClauseProcessor.createLinearVar(
3768 builder, moduleTranslation,
3769 moduleTranslation.lookupValue(linearVar), idx);
3770 }
3771 for (mlir::Value linearStep : simdOp.getLinearStepVars())
3772 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3773 }
3774
3775 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
3776 moduleTranslation, allocaIP, reductionDecls,
3777 privateReductionVariables, reductionVariableMap,
3778 deferredStores, isByRef)))
3779 return failure();
3780
3781 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3782 opInst)
3783 .failed())
3784 return failure();
3785
3786 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
3787 // SIMD.
3788
3789 assert(afterAllocas.get()->getSinglePredecessor());
3790 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3791 moduleTranslation,
3792 afterAllocas.get()->getSinglePredecessor(),
3793 reductionDecls, privateReductionVariables,
3794 reductionVariableMap, isByRef, deferredStores)))
3795 return failure();
3796
// Lower simdlen/safelen clauses to i64 constants for applySimd.
3797 llvm::ConstantInt *simdlen = nullptr;
3798 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3799 simdlen = builder.getInt64(simdlenVar.value());
3800
3801 llvm::ConstantInt *safelen = nullptr;
3802 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3803 safelen = builder.getInt64(safelenVar.value());
3804
3805 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3806 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
3807
// Collect aligned-clause variables; loads are emitted in the pre-loop block
// so the alignment assumption applies before the loop body runs.
3808 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3809 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3810 mlir::OperandRange operands = simdOp.getAlignedVars();
3811 for (size_t i = 0; i < operands.size(); ++i) {
3812 llvm::Value *alignment = nullptr;
3813 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
3814 llvm::Type *ty = llvmVal->getType();
3815
3816 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3817 alignment = builder.getInt64(intAttr.getInt());
3818 assert(ty->isPointerTy() && "Invalid type for aligned variable");
3819 assert(alignment && "Invalid alignment value");
3820
3821 // Check if the alignment value is not a power of 2. If so, skip emitting
3822 // alignment.
3823 if (!intAttr.getValue().isPowerOf2())
3824 continue;
3825
3826 auto curInsert = builder.saveIP();
3827 builder.SetInsertPoint(sourceBlock);
3828 llvmVal = builder.CreateLoad(ty, llvmVal);
3829 builder.restoreIP(curInsert);
3830 alignedVars[llvmVal] = alignment;
3831 }
3832
3834 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
3835
3836 if (failed(handleError(regionBlock, opInst)))
3837 return failure();
3838
3839 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3840 // Emit Initialization for linear variables
3841 if (simdOp.getLinearVars().size()) {
3842 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3843 loopInfo->getPreheader());
3844
3845 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3846 loopInfo->getIndVar());
3847 }
3848 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3849
// Attach the simd metadata/transformations to the canonical loop.
3850 ompBuilder->applySimd(loopInfo, alignedVars,
3851 simdOp.getIfExpr()
3852 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
3853 : nullptr,
3854 order, simdlen, safelen);
3855
3856 linearClauseProcessor.emitStoresForLinearVar(builder);
3857
3858 // Check if this SIMD loop contains ordered regions
3859 bool hasOrderedRegions = false;
3860 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
3861 hasOrderedRegions = true;
3862 return WalkResult::interrupt();
3863 });
3864
3865 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
3866 llvm::BasicBlock *startBB = sourceBlock->getSingleSuccessor();
3867 llvm::BasicBlock *endBB = *regionBlock;
3868 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3869 "omp.loop_nest.region", index);
3870
3871 if (hasOrderedRegions) {
3872 // Also rewrite uses in ordered regions so they read the current value
3873 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3874 "omp.ordered.region", index);
3875 // Also rewrite uses in finalize blocks (code after ordered regions)
3876 linearClauseProcessor.rewriteInPlace(builder, startBB, endBB,
3877 "omp_region.finalize", index);
3878 }
3879 }
3880
3881 // We now need to reduce the per-simd-lane reduction variable into the
3882 // original variable. This works a bit differently to other reductions (e.g.
3883 // wsloop) because we don't need to call into the OpenMP runtime to handle
3884 // threads: everything happened in this one thread.
3885 for (auto [i, tuple] : llvm::enumerate(
3886 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3887 privateReductionVariables))) {
3888 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3889
3890 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
3891 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
3892 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
3893
3894 // We have one less load for by-ref case because that load is now inside of
3895 // the reduction region.
3896 llvm::Value *redValue = originalVariable;
3897 if (!byRef)
3898 redValue =
3899 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
3900 llvm::Value *privateRedValue = builder.CreateLoad(
3901 reductionType, privateReductionVar, "red.private.value." + Twine(i));
3902 llvm::Value *reduced;
3903
3904 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3905 if (failed(handleError(res, opInst)))
3906 return failure();
3907 builder.restoreIP(res.get());
3908
3909 // For by-ref case, the store is inside of the reduction region.
3910 if (!byRef)
3911 builder.CreateStore(reduced, originalVariable);
3912 }
3913
3914 // After the construct, deallocate private reduction variables.
3915 SmallVector<Region *> reductionRegions;
3916 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3917 [](omp::DeclareReductionOp reductionDecl) {
3918 return &reductionDecl.getCleanupRegion();
3919 });
3920 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
3921 moduleTranslation, builder,
3922 "omp.reduction.cleanup")))
3923 return failure();
3924
// Finally, run the privatizers' dealloc regions for the private copies.
3925 return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
3926 privateVarsInfo.llvmVars,
3927 privateVarsInfo.privatizers);
3928}
3929
3930/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
3931static LogicalResult
3932convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
3933 LLVM::ModuleTranslation &moduleTranslation) {
3934 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3935 auto loopOp = cast<omp::LoopNestOp>(opInst);
3936
3937 if (failed(checkImplementationStatus(opInst)))
3938 return failure();
3939
3940 // Set up the source location value for OpenMP runtime.
3941 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3942
3943 // Generator of the canonical loop body.
3946 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3947 llvm::Value *iv) -> llvm::Error {
3948 // Make sure further conversions know about the induction variable.
3949 moduleTranslation.mapValue(
3950 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3951
3952 // Capture the body insertion point for use in nested loops. BodyIP of the
3953 // CanonicalLoopInfo always points to the beginning of the entry block of
3954 // the body.
3955 bodyInsertPoints.push_back(ip);
3956
3957 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3958 return llvm::Error::success();
3959
3960 // Convert the body of the loop.
3961 builder.restoreIP(ip);
3963 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
3964 if (!regionBlock)
3965 return regionBlock.takeError();
3966
3967 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3968 return llvm::Error::success();
3969 };
3970
3971 // Delegate actual loop construction to the OpenMP IRBuilder.
3972 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
3973 // loop, i.e. it has a positive step, uses signed integer semantics.
3974 // Reconsider this code when the nested loop operation clearly supports more
3975 // cases.
3976 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3977 llvm::Value *lowerBound =
3978 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3979 llvm::Value *upperBound =
3980 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3981 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
3982
3983 // Make sure loop trip count are emitted in the preheader of the outermost
3984 // loop at the latest so that they are all available for the new collapsed
3985 // loop will be created below.
3986 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3987 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3988 if (i != 0) {
3989 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3990 ompLoc.DL);
3991 computeIP = loopInfos.front()->getPreheaderIP();
3992 }
3993
3995 ompBuilder->createCanonicalLoop(
3996 loc, bodyGen, lowerBound, upperBound, step,
3997 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
3998
3999 if (failed(handleError(loopResult, *loopOp)))
4000 return failure();
4001
4002 loopInfos.push_back(*loopResult);
4003 }
4004
4005 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4006 loopInfos.front()->getAfterIP();
4007
4008 // Do tiling.
4009 if (const auto &tiles = loopOp.getTileSizes()) {
4010 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4012
4013 for (auto tile : tiles.value()) {
4014 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
4015 tileSizes.push_back(tileVal);
4016 }
4017
4018 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4019 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4020
4021 // Update afterIP to get the correct insertion point after
4022 // tiling.
4023 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4024 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4025 afterIP = {afterAfterBB, afterAfterBB->begin()};
4026
4027 // Update the loop infos.
4028 loopInfos.clear();
4029 for (const auto &newLoop : newLoops)
4030 loopInfos.push_back(newLoop);
4031 } // Tiling done.
4032
4033 // Do collapse.
4034 const auto &numCollapse = loopOp.getCollapseNumLoops();
4036 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4037
4038 auto newTopLoopInfo =
4039 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4040
4041 assert(newTopLoopInfo && "New top loop information is missing");
4042 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
4043 [&](OpenMPLoopInfoStackFrame &frame) {
4044 frame.loopInfo = newTopLoopInfo;
4045 return WalkResult::interrupt();
4046 });
4047
4048 // Continue building IR after the loop. Note that the LoopInfo returned by
4049 // `collapseLoops` points inside the outermost loop and is intended for
4050 // potential further loop transformations. Use the insertion point stored
4051 // before collapsing loops instead.
4052 builder.restoreIP(afterIP);
4053 return success();
4054}
4055
4056/// Convert an omp.canonical_loop to LLVM-IR
4057static LogicalResult
4058convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
4059 LLVM::ModuleTranslation &moduleTranslation) {
4060 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4061
4062 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4063 Value loopIV = op.getInductionVar();
4064 Value loopTC = op.getTripCount();
4065
4066 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
4067
4069 ompBuilder->createCanonicalLoop(
4070 loopLoc,
4071 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4072 // Register the mapping of MLIR induction variable to LLVM-IR
4073 // induction variable
4074 moduleTranslation.mapValue(loopIV, llvmIV);
4075
4076 builder.restoreIP(ip);
4078 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
4079 moduleTranslation);
4080
4081 return bodyGenStatus.takeError();
4082 },
4083 llvmTC, "omp.loop");
4084 if (!llvmOrError)
4085 return op.emitError(llvm::toString(llvmOrError.takeError()));
4086
4087 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4088 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4089 builder.restoreIP(afterIP);
4090
4091 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
4092 if (Value cli = op.getCli())
4093 moduleTranslation.mapOmpLoop(cli, llvmCLI);
4094
4095 return success();
4096}
4097
4098/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
4099/// OpenMPIRBuilder.
4100static LogicalResult
4101applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
4102 LLVM::ModuleTranslation &moduleTranslation) {
4103 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4104
4105 Value applyee = op.getApplyee();
4106 assert(applyee && "Loop to apply unrolling on required");
4107
4108 llvm::CanonicalLoopInfo *consBuilderCLI =
4109 moduleTranslation.lookupOMPLoop(applyee);
4110 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4111 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4112
4113 moduleTranslation.invalidateOmpLoop(applyee);
4114 return success();
4115}
4116
4117/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
4118/// OpenMPIRBuilder.
4119static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4120 LLVM::ModuleTranslation &moduleTranslation) {
4121 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4122 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4123
4125 SmallVector<llvm::Value *> translatedSizes;
4126
4127 for (Value size : op.getSizes()) {
4128 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
4129 assert(translatedSize &&
4130 "sizes clause arguments must already be translated");
4131 translatedSizes.push_back(translatedSize);
4132 }
4133
4134 for (Value applyee : op.getApplyees()) {
4135 llvm::CanonicalLoopInfo *consBuilderCLI =
4136 moduleTranslation.lookupOMPLoop(applyee);
4137 assert(applyee && "Canonical loop must already been translated");
4138 translatedLoops.push_back(consBuilderCLI);
4139 }
4140
4141 auto generatedLoops =
4142 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4143 if (!op.getGeneratees().empty()) {
4144 for (auto [mlirLoop, genLoop] :
4145 zip_equal(op.getGeneratees(), generatedLoops))
4146 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
4147 }
4148
4149 // CLIs can only be consumed once
4150 for (Value applyee : op.getApplyees())
4151 moduleTranslation.invalidateOmpLoop(applyee);
4152
4153 return success();
4154}
4155
4156/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
4157/// OpenMPIRBuilder.
4158static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4159 LLVM::ModuleTranslation &moduleTranslation) {
4160 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4161 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4162
4163 // Select what CLIs are going to be fused
4164 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
4165 for (size_t i = 0; i < op.getApplyees().size(); i++) {
4166 Value applyee = op.getApplyees()[i];
4167 llvm::CanonicalLoopInfo *consBuilderCLI =
4168 moduleTranslation.lookupOMPLoop(applyee);
4169 assert(applyee && "Canonical loop must already been translated");
4170 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4171 beforeFuse.push_back(consBuilderCLI);
4172 else if (op.getCount().has_value() &&
4173 i >= op.getFirst().value() + op.getCount().value() - 1)
4174 afterFuse.push_back(consBuilderCLI);
4175 else
4176 toFuse.push_back(consBuilderCLI);
4177 }
4178 assert(
4179 (op.getGeneratees().empty() ||
4180 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4181 "Wrong number of generatees");
4182
4183 // do the fuse
4184 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4185 if (!op.getGeneratees().empty()) {
4186 size_t i = 0;
4187 for (; i < beforeFuse.size(); i++)
4188 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4189 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4190 for (; i < afterFuse.size(); i++)
4191 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4192 }
4193
4194 // CLIs can only be consumed once
4195 for (Value applyee : op.getApplyees())
4196 moduleTranslation.invalidateOmpLoop(applyee);
4197
4198 return success();
4199}
4200
4201/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
4202static llvm::AtomicOrdering
4203convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
4204 if (!ao)
4205 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
4206
4207 switch (*ao) {
4208 case omp::ClauseMemoryOrderKind::Seq_cst:
4209 return llvm::AtomicOrdering::SequentiallyConsistent;
4210 case omp::ClauseMemoryOrderKind::Acq_rel:
4211 return llvm::AtomicOrdering::AcquireRelease;
4212 case omp::ClauseMemoryOrderKind::Acquire:
4213 return llvm::AtomicOrdering::Acquire;
4214 case omp::ClauseMemoryOrderKind::Release:
4215 return llvm::AtomicOrdering::Release;
4216 case omp::ClauseMemoryOrderKind::Relaxed:
4217 return llvm::AtomicOrdering::Monotonic;
4218 }
4219 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
4220}
4221
4222/// Convert omp.atomic.read operation to LLVM IR.
4223static LogicalResult
4224convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
4225 LLVM::ModuleTranslation &moduleTranslation) {
4226 auto readOp = cast<omp::AtomicReadOp>(opInst);
4227 if (failed(checkImplementationStatus(opInst)))
4228 return failure();
4229
4230 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4231 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4232 findAllocaInsertPoint(builder, moduleTranslation);
4233
4234 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4235
4236 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
4237 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
4238 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
4239
4240 llvm::Type *elementType =
4241 moduleTranslation.convertType(readOp.getElementType());
4242
4243 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
4244 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
4245 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4246 return success();
4247}
4248
4249/// Converts an omp.atomic.write operation to LLVM IR.
4250static LogicalResult
4251convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
4252 LLVM::ModuleTranslation &moduleTranslation) {
4253 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4254 if (failed(checkImplementationStatus(opInst)))
4255 return failure();
4256
4257 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4258 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4259 findAllocaInsertPoint(builder, moduleTranslation);
4260
4261 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4262 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
4263 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
4264 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
4265 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
4266 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
4267 /*isVolatile=*/false};
4268 builder.restoreIP(
4269 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4270 return success();
4271}
4272
4273/// Converts an LLVM dialect binary operation to the corresponding enum value
4274/// for `atomicrmw` supported binary operation.
4275static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
4277 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
4278 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
4279 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
4280 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
4281 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
4282 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
4283 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
4284 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
4285 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
4286 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4287}
4288
4289static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
4290 bool &isIgnoreDenormalMode,
4291 bool &isFineGrainedMemory,
4292 bool &isRemoteMemory) {
4293 isIgnoreDenormalMode = false;
4294 isFineGrainedMemory = false;
4295 isRemoteMemory = false;
4296 if (atomicUpdateOp &&
4297 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4298 mlir::omp::AtomicControlAttr atomicControlAttr =
4299 atomicUpdateOp.getAtomicControlAttr();
4300 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4301 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4302 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4303 }
4304}
4305
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// If the update region consists of exactly one supported binary op plus the
/// terminator, an `atomicrmw` can be emitted; otherwise the IRBuilder falls
/// back to a cmpxchg loop driven by the `updateFn` callback below.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // `isXBinopExpr` records whether the expression is `x binop expr` (true)
    // or `expr binop x` (false); the IRBuilder needs this for non-commutative
    // ops such as sub.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code. The callback converts the region body with the
  // atomically-loaded value substituted for the region argument and returns
  // the value yielded by the terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  // Continue emitting IR after the atomic construct.
  builder.restoreIP(*afterIP);
  return success();
}
4390
/// Converts an omp.atomic.capture operation to LLVM IR.
///
/// The capture region wraps an atomic read together with either an atomic
/// update or an atomic write; `isPostfixUpdate` records whether the capture
/// returns the value before (postfix, `v = x; x op= e`) or after the
/// modification.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // A write always captures the old value, i.e. behaves as a postfix
    // update.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix iff the update is the second op in the capture region (the
    // read happens first).
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      // Records `x binop expr` (true) vs `expr binop x` (false) for
      // non-commutative ops.
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one op in the region: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback used by the IRBuilder for the cmpxchg fallback: converts the
  // update region body (or simply returns the write expression) and yields
  // the new value of x.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  // Continue emitting IR after the atomic construct.
  builder.restoreIP(*afterIP);
  return success();
}
4491
4492static llvm::omp::Directive convertCancellationConstructType(
4493 omp::ClauseCancellationConstructType directive) {
4494 switch (directive) {
4495 case omp::ClauseCancellationConstructType::Loop:
4496 return llvm::omp::Directive::OMPD_for;
4497 case omp::ClauseCancellationConstructType::Parallel:
4498 return llvm::omp::Directive::OMPD_parallel;
4499 case omp::ClauseCancellationConstructType::Sections:
4500 return llvm::omp::Directive::OMPD_sections;
4501 case omp::ClauseCancellationConstructType::Taskgroup:
4502 return llvm::omp::Directive::OMPD_taskgroup;
4503 }
4504 llvm_unreachable("Unhandled cancellation construct type");
4505}
4506
4507static LogicalResult
4508convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4509 LLVM::ModuleTranslation &moduleTranslation) {
4510 if (failed(checkImplementationStatus(*op.getOperation())))
4511 return failure();
4512
4513 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4514 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4515
4516 llvm::Value *ifCond = nullptr;
4517 if (Value ifVar = op.getIfExpr())
4518 ifCond = moduleTranslation.lookupValue(ifVar);
4519
4520 llvm::omp::Directive cancelledDirective =
4521 convertCancellationConstructType(op.getCancelDirective());
4522
4523 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4524 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4525
4526 if (failed(handleError(afterIP, *op.getOperation())))
4527 return failure();
4528
4529 builder.restoreIP(afterIP.get());
4530
4531 return success();
4532}
4533
4534static LogicalResult
4535convertOmpCancellationPoint(omp::CancellationPointOp op,
4536 llvm::IRBuilderBase &builder,
4537 LLVM::ModuleTranslation &moduleTranslation) {
4538 if (failed(checkImplementationStatus(*op.getOperation())))
4539 return failure();
4540
4541 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4542 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4543
4544 llvm::omp::Directive cancelledDirective =
4545 convertCancellationConstructType(op.getCancelDirective());
4546
4547 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4548 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4549
4550 if (failed(handleError(afterIP, *op.getOperation())))
4551 return failure();
4552
4553 builder.restoreIP(afterIP.get());
4554
4555 return success();
4556}
4557
4558/// Converts an OpenMP Threadprivate operation into LLVM IR using
4559/// OpenMPIRBuilder.
4560static LogicalResult
4561convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4562 LLVM::ModuleTranslation &moduleTranslation) {
4563 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4564 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4565 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4566
4567 if (failed(checkImplementationStatus(opInst)))
4568 return failure();
4569
4570 Value symAddr = threadprivateOp.getSymAddr();
4571 auto *symOp = symAddr.getDefiningOp();
4572
4573 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4574 symOp = asCast.getOperand().getDefiningOp();
4575
4576 if (!isa<LLVM::AddressOfOp>(symOp))
4577 return opInst.emitError("Addressing symbol not found");
4578 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4579
4580 LLVM::GlobalOp global =
4581 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4582 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4583 llvm::Type *type = globalValue->getValueType();
4584 llvm::TypeSize typeSize =
4585 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4586 type);
4587 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4588 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4589 ompLoc, globalValue, size, global.getSymName() + ".cache");
4590 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4591
4592 return success();
4593}
4594
4595static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4596convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4597 switch (deviceClause) {
4598 case mlir::omp::DeclareTargetDeviceType::host:
4599 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4600 break;
4601 case mlir::omp::DeclareTargetDeviceType::nohost:
4602 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4603 break;
4604 case mlir::omp::DeclareTargetDeviceType::any:
4605 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4606 break;
4607 }
4608 llvm_unreachable("unhandled device clause");
4609}
4610
4611static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4613 mlir::omp::DeclareTargetCaptureClause captureClause) {
4614 switch (captureClause) {
4615 case mlir::omp::DeclareTargetCaptureClause::to:
4616 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4617 case mlir::omp::DeclareTargetCaptureClause::link:
4618 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4619 case mlir::omp::DeclareTargetCaptureClause::enter:
4620 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4621 case mlir::omp::DeclareTargetCaptureClause::none:
4622 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4623 }
4624 llvm_unreachable("unhandled capture clause");
4625}
4626
4628 Operation *op = value.getDefiningOp();
4629 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4630 op = addrCast->getOperand(0).getDefiningOp();
4631 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4632 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4633 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4634 }
4635 return nullptr;
4636}
4637
/// Build the "_decl_tgt_ref_ptr" suffix used to name the reference pointer of
/// a declare-target global. Private globals additionally get a per-file
/// unique id (derived from the source file) so names cannot collide.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): findInstanceOf may yield a null FileLineColLoc when the
    // global carries no file location; the callback below would then call
    // methods on it — confirm callers guarantee a file location here.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    // Hash the file identity into the suffix, mirroring how target entry
    // unique info is derived elsewhere.
    auto vfs = llvm::vfs::getRealFileSystem();
    os << llvm::format(
        "_%x",
        ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
4659
4660static bool isDeclareTargetLink(Value value) {
4661 if (auto declareTargetGlobal =
4662 dyn_cast_if_present<omp::DeclareTargetInterface>(
4663 getGlobalOpFromValue(value)))
4664 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4665 omp::DeclareTargetCaptureClause::link)
4666 return true;
4667 return false;
4668}
4669
4670static bool isDeclareTargetTo(Value value) {
4671 if (auto declareTargetGlobal =
4672 dyn_cast_if_present<omp::DeclareTargetInterface>(
4673 getGlobalOpFromValue(value)))
4674 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4675 omp::DeclareTargetCaptureClause::to ||
4676 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4677 omp::DeclareTargetCaptureClause::enter)
4678 return true;
4679 return false;
4680}
4681
4682// Returns the reference pointer generated by the lowering of the declare
4683// target operation in cases where the link clause is used or the to clause is
4684// used in USM mode.
4685static llvm::Value *
4687 LLVM::ModuleTranslation &moduleTranslation) {
4688 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4689 if (auto gOp =
4690 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
4691 if (auto declareTargetGlobal =
4692 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4693 // In this case, we must utilise the reference pointer generated by
4694 // the declare target operation, similar to Clang
4695 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4696 omp::DeclareTargetCaptureClause::link) ||
4697 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4698 omp::DeclareTargetCaptureClause::to &&
4699 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4700 llvm::SmallString<64> suffix =
4701 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
4702
4703 if (gOp.getSymName().contains(suffix))
4704 return moduleTranslation.getLLVMModule()->getNamedValue(
4705 gOp.getSymName());
4706
4707 return moduleTranslation.getLLVMModule()->getNamedValue(
4708 (gOp.getSymName().str() + suffix.str()).str());
4709 }
4710 }
4711 }
4712 return nullptr;
4713}
4714
namespace {
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // Custom-mapper operations, parallel to the base class's per-argument
  // arrays.
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalesce it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than necessary.
struct MapInfoData : MapInfosTy {
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  void append(MapInfoData &CurInfo) {
    // Note: IsAMember and IsAMapping are not appended here; presumably they
    // are filled separately by the caller — confirm against the call sites.
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};

// Discriminates the kind of target construct an operation represents; used
// to select directive-specific lowering behavior.
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};

// Classify `op` into the TargetDirectiveEnumTy variant it corresponds to;
// returns None for operations that are not target constructs.
static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
  return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
      .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
      .Case([](omp::TargetEnterDataOp) {
        return TargetDirectiveEnumTy::TargetEnterData;
      })
      .Case([&](omp::TargetExitDataOp) {
        return TargetDirectiveEnumTy::TargetExitData;
      })
      .Case([&](omp::TargetUpdateOp) {
        return TargetDirectiveEnumTy::TargetUpdate;
      })
      .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
      .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
}

} // namespace
4781
4782static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4783 DataLayout &dl) {
4784 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4785 arrTy.getElementType()))
4786 return getArrayElementSizeInBits(nestedArrTy, dl);
4787 return dl.getTypeSizeInBits(arrTy.getElementType());
4788}
4789
4790// This function calculates the size to be offloaded for a specified type, given
4791// its associated map clause (which can contain bounds information which affects
4792// the total size), this size is calculated based on the underlying element type
4793// e.g. given a 1-D array of ints, we will calculate the size from the integer
4794// type * number of elements in the array. This size can be used in other
4795// calculations but is ultimately used as an argument to the OpenMP runtimes
4796// kernel argument structure which is generated through the combinedInfo data
4797// structures.
4798// This function is somewhat equivalent to Clang's getExprTypeSize inside of
4799// CGOpenMPRuntime.cpp.
4800static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
4801 Operation *clauseOp,
4802 llvm::Value *basePointer,
4803 llvm::Type *baseType,
4804 llvm::IRBuilderBase &builder,
4805 LLVM::ModuleTranslation &moduleTranslation) {
4806 if (auto memberClause =
4807 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
4808 // This calculates the size to transfer based on bounds and the underlying
4809 // element type, provided bounds have been specified (Fortran
4810 // pointers/allocatables/target and arrays that have sections specified fall
4811 // into this as well)
4812 if (!memberClause.getBounds().empty()) {
4813 llvm::Value *elementCount = builder.getInt64(1);
4814 for (auto bounds : memberClause.getBounds()) {
4815 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
4816 bounds.getDefiningOp())) {
4817 // The below calculation for the size to be mapped calculated from the
4818 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
4819 // multiply by the underlying element types byte size to get the full
4820 // size to be offloaded based on the bounds
4821 elementCount = builder.CreateMul(
4822 elementCount,
4823 builder.CreateAdd(
4824 builder.CreateSub(
4825 moduleTranslation.lookupValue(boundOp.getUpperBound()),
4826 moduleTranslation.lookupValue(boundOp.getLowerBound())),
4827 builder.getInt64(1)));
4828 }
4829 }
4830
4831 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
4832 // the size in inconsistent byte or bit format.
4833 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
4834 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
4835 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
4836
4837 // The size in bytes x number of elements, the sizeInBytes stored is
4838 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
4839 // size, so we do some on the fly runtime math to get the size in
4840 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
4841 // some adjustment for members with more complex types.
4842 return builder.CreateMul(elementCount,
4843 builder.getInt64(underlyingTypeSzInBits / 8));
4844 }
4845 }
4846
4847 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
4848}
4849
4850// Convert the MLIR map flag set to the runtime map flag set for embedding
4851// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4852// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4853// Certain flags are discarded here such as RefPtee and co.
4854static llvm::omp::OpenMPOffloadMappingFlags
4855convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4856 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4857 return (mlirFlags & flag) == flag;
4858 };
4859 const bool hasExplicitMap =
4860 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4861 omp::ClauseMapFlags::none;
4862
4863 llvm::omp::OpenMPOffloadMappingFlags mapType =
4864 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4865
4866 if (mapTypeToBool(omp::ClauseMapFlags::to))
4867 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4868
4869 if (mapTypeToBool(omp::ClauseMapFlags::from))
4870 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4871
4872 if (mapTypeToBool(omp::ClauseMapFlags::always))
4873 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4874
4875 if (mapTypeToBool(omp::ClauseMapFlags::del))
4876 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4877
4878 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4879 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4880
4881 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4882 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4883
4884 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4885 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4886
4887 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4888 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4889
4890 if (mapTypeToBool(omp::ClauseMapFlags::close))
4891 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4892
4893 if (mapTypeToBool(omp::ClauseMapFlags::present))
4894 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4895
4896 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4897 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4898
4899 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4900 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4901
4902 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4903 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4904 if (!hasExplicitMap)
4905 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4906 }
4907
4908 return mapType;
4909}
4910
4912 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
4913 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
4914 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
4915 ArrayRef<Value> useDevAddrOperands = {},
4916 ArrayRef<Value> hasDevAddrOperands = {}) {
4917 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4918 // Check if this is a member mapping and correctly assign that it is, if
4919 // it is a member of a larger object.
4920 // TODO: Need better handling of members, and distinguishing of members
4921 // that are implicitly allocated on device vs explicitly passed in as
4922 // arguments.
4923 // TODO: May require some further additions to support nested record
4924 // types, i.e. member maps that can have member maps.
4925 for (Value mapValue : mapVars) {
4926 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4927 for (auto member : map.getMembers())
4928 if (member == mapOp)
4929 return true;
4930 }
4931 return false;
4932 };
4933
4934 // Process MapOperands
4935 for (Value mapValue : mapVars) {
4936 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4937 Value offloadPtr =
4938 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4939 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4940 mapData.Pointers.push_back(mapData.OriginalValue.back());
4941
4942 if (llvm::Value *refPtr =
4943 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
4944 mapData.IsDeclareTarget.push_back(true);
4945 mapData.BasePointers.push_back(refPtr);
4946 } else if (isDeclareTargetTo(offloadPtr)) {
4947 mapData.IsDeclareTarget.push_back(true);
4948 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4949 } else { // regular mapped variable
4950 mapData.IsDeclareTarget.push_back(false);
4951 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4952 }
4953
4954 mapData.BaseType.push_back(
4955 moduleTranslation.convertType(mapOp.getVarType()));
4956 mapData.Sizes.push_back(
4957 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4958 mapData.BaseType.back(), builder, moduleTranslation));
4959 mapData.MapClause.push_back(mapOp.getOperation());
4960 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
4961 mapData.Names.push_back(LLVM::createMappingInformation(
4962 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4963 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4964 if (mapOp.getMapperId())
4965 mapData.Mappers.push_back(
4967 mapOp, mapOp.getMapperIdAttr()));
4968 else
4969 mapData.Mappers.push_back(nullptr);
4970 mapData.IsAMapping.push_back(true);
4971 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4972 }
4973
4974 auto findMapInfo = [&mapData](llvm::Value *val,
4975 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4976 unsigned index = 0;
4977 bool found = false;
4978 for (llvm::Value *basePtr : mapData.OriginalValue) {
4979 if (basePtr == val && mapData.IsAMapping[index]) {
4980 found = true;
4981 mapData.Types[index] |=
4982 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4983 mapData.DevicePointers[index] = devInfoTy;
4984 }
4985 index++;
4986 }
4987 return found;
4988 };
4989
4990 // Process useDevPtr(Addr)Operands
4991 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4992 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4993 for (Value mapValue : useDevOperands) {
4994 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4995 Value offloadPtr =
4996 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4997 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4998
4999 // Check if map info is already present for this entry.
5000 if (!findMapInfo(origValue, devInfoTy)) {
5001 mapData.OriginalValue.push_back(origValue);
5002 mapData.Pointers.push_back(mapData.OriginalValue.back());
5003 mapData.IsDeclareTarget.push_back(false);
5004 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5005 mapData.BaseType.push_back(
5006 moduleTranslation.convertType(mapOp.getVarType()));
5007 mapData.Sizes.push_back(builder.getInt64(0));
5008 mapData.MapClause.push_back(mapOp.getOperation());
5009 mapData.Types.push_back(
5010 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5011 mapData.Names.push_back(LLVM::createMappingInformation(
5012 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5013 mapData.DevicePointers.push_back(devInfoTy);
5014 mapData.Mappers.push_back(nullptr);
5015 mapData.IsAMapping.push_back(false);
5016 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5017 }
5018 }
5019 };
5020
5021 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5022 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5023
5024 for (Value mapValue : hasDevAddrOperands) {
5025 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5026 Value offloadPtr =
5027 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5028 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5029 auto mapType = convertClauseMapFlags(mapOp.getMapType());
5030 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5031 bool isDevicePtr =
5032 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5033 omp::ClauseMapFlags::none;
5034
5035 mapData.OriginalValue.push_back(origValue);
5036 mapData.BasePointers.push_back(origValue);
5037 mapData.Pointers.push_back(origValue);
5038 mapData.IsDeclareTarget.push_back(false);
5039 mapData.BaseType.push_back(
5040 moduleTranslation.convertType(mapOp.getVarType()));
5041 mapData.Sizes.push_back(
5042 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
5043 mapData.MapClause.push_back(mapOp.getOperation());
5044 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5045 // Descriptors are mapped with the ALWAYS flag, since they can get
      // rematerialized, so the address of the descriptor for a given object
5047 // may change from one place to another.
5048 mapData.Types.push_back(mapType);
5049 // Technically it's possible for a non-descriptor mapping to have
5050 // both has-device-addr and ALWAYS, so lookup the mapper in case it
5051 // exists.
5052 if (mapOp.getMapperId()) {
5053 mapData.Mappers.push_back(
5055 mapOp, mapOp.getMapperIdAttr()));
5056 } else {
5057 mapData.Mappers.push_back(nullptr);
5058 }
5059 } else {
5060 // For is_device_ptr we need the map type to propagate so the runtime
5061 // can materialize the device-side copy of the pointer container.
5062 mapData.Types.push_back(
5063 isDevicePtr ? mapType
5064 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5065 mapData.Mappers.push_back(nullptr);
5066 }
5067 mapData.Names.push_back(LLVM::createMappingInformation(
5068 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5069 mapData.DevicePointers.push_back(
5070 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5071 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5072 mapData.IsAMapping.push_back(false);
5073 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5074 }
5075}
5076
5077static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
5078 auto *res = llvm::find(mapData.MapClause, memberOp);
5079 assert(res != mapData.MapClause.end() &&
5080 "MapInfoOp for member not found in MapData, cannot return index");
5081 return std::distance(mapData.MapClause.begin(), res);
5082}
5083
5085 omp::MapInfoOp mapInfo) {
5086 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5087 llvm::SmallVector<size_t> occludedChildren;
5088 llvm::sort(
5089 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
5090 // Bail early if we are asked to look at the same index. If we do not
5091 // bail early, we can end up mistakenly adding indices to
5092 // occludedChildren. This can occur with some types of libc++ hardening.
5093 if (a == b)
5094 return false;
5095
5096 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5097 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5098
5099 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5100 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5101 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5102
5103 if (aIndex == bIndex)
5104 continue;
5105
5106 if (aIndex < bIndex)
5107 return true;
5108
5109 if (aIndex > bIndex)
5110 return false;
5111 }
5112
5113 // Iterated up until the end of the smallest member and
5114 // they were found to be equal up to that point, so select
5115 // the member with the lowest index count, so the "parent"
5116 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5117 if (memberAParent)
5118 occludedChildren.push_back(b);
5119 else
5120 occludedChildren.push_back(a);
5121 return memberAParent;
5122 });
5123
5124 // We remove children from the index list that are overshadowed by
5125 // a parent, this prevents us retrieving these as the first or last
5126 // element when the parent is the correct element in these cases.
5127 for (auto v : occludedChildren)
5128 indices.erase(std::remove(indices.begin(), indices.end(), v),
5129 indices.end());
5130}
5131
5132static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
5133 bool first) {
5134 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5135 // Only 1 member has been mapped, we can return it.
5136 if (indexAttr.size() == 1)
5137 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5138 llvm::SmallVector<size_t> indices(indexAttr.size());
5139 std::iota(indices.begin(), indices.end(), 0);
5140 sortMapIndices(indices, mapInfo);
5141 return llvm::cast<omp::MapInfoOp>(
5142 mapInfo.getMembers()[first ? indices.front() : indices.back()]
5143 .getDefiningOp());
5144}
5145
5146/// This function calculates the array/pointer offset for map data provided
5147/// with bounds operations, e.g. when provided something like the following:
5148///
5149/// Fortran
5150/// map(tofrom: array(2:5, 3:2))
5151///
5152/// We must calculate the initial pointer offset to pass across, this function
5153/// performs this using bounds.
5154///
5155/// TODO/WARNING: This only supports Fortran's column major indexing currently
5156/// as is noted in the note below and comments in the function, we must extend
5157/// this function when we add a C++ frontend.
5158/// NOTE: which while specified in row-major order it currently needs to be
5159/// flipped for Fortran's column order array allocation and access (as
5160/// opposed to C++'s row-major, hence the backwards processing where order is
5161/// important). This is likely important to keep in mind for the future when
5162/// we incorporate a C++ frontend, both frontends will need to agree on the
5163/// ordering of generated bounds operations (one may have to flip them) to
5164/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
5166static std::vector<llvm::Value *>
5168 llvm::IRBuilderBase &builder, bool isArrayTy,
5169 OperandRange bounds) {
5170 std::vector<llvm::Value *> idx;
5171 // There's no bounds to calculate an offset from, we can safely
5172 // ignore and return no indices.
5173 if (bounds.empty())
5174 return idx;
5175
5176 // If we have an array type, then we have its type so can treat it as a
5177 // normal GEP instruction where the bounds operations are simply indexes
5178 // into the array. We currently do reverse order of the bounds, which
5179 // I believe leans more towards Fortran's column-major in memory.
5180 if (isArrayTy) {
5181 idx.push_back(builder.getInt64(0));
5182 for (int i = bounds.size() - 1; i >= 0; --i) {
5183 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5184 bounds[i].getDefiningOp())) {
5185 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
5186 }
5187 }
5188 } else {
5189 // If we do not have an array type, but we have bounds, then we're dealing
5190 // with a pointer that's being treated like an array and we have the
5191 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
5193 // offset using a single index which the following loop attempts to
5194 // compute using the standard column-major algorithm e.g for a 3D array:
5195 //
5196 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
5197 //
5198 // It is of note that it's doing column-major rather than row-major at the
5199 // moment, but having a way for the frontend to indicate which major format
5200 // to use or standardizing/canonicalizing the order of the bounds to compute
5201 // the offset may be useful in the future when there's other frontends with
5202 // different formats.
5203 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5204 for (int i = bounds.size() - 1; i >= 0; --i) {
5205 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5206 bounds[i].getDefiningOp())) {
5207 if (i == ((int)bounds.size() - 1))
5208 idx.emplace_back(
5209 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5210 else
5211 idx.back() = builder.CreateAdd(
5212 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
5213 boundOp.getExtent())),
5214 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5215 }
5216 }
5217 }
5218
5219 return idx;
5220}
5221
5223 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
5224 return cast<IntegerAttr>(value).getInt();
5225 });
5226}
5227
5228// Gathers members that are overlapping in the parent, excluding members that
5229// themselves overlap, keeping the top-most (closest to parents level) map.
5230static void
5232 omp::MapInfoOp parentOp) {
5233 // No members mapped, no overlaps.
5234 if (parentOp.getMembers().empty())
5235 return;
5236
5237 // Single member, we can insert and return early.
5238 if (parentOp.getMembers().size() == 1) {
5239 overlapMapDataIdxs.push_back(0);
5240 return;
5241 }
5242
5243 // 1) collect list of top-level overlapping members from MemberOp
5245 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5246 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5247 memberByIndex.push_back(
5248 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
5249
5250 // Sort the smallest first (higher up the parent -> member chain), so that
5251 // when we remove members, we remove as much as we can in the initial
5252 // iterations, shortening the number of passes required.
5253 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5254 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
5255
5256 // Remove elements from the vector if there is a parent element that
5257 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
5258 // [0,2].. etc.
5260 for (auto v : memberByIndex) {
5261 llvm::SmallVector<int64_t> vArr(v.second.size());
5262 getAsIntegers(v.second, vArr);
5263 skipList.push_back(
5264 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
5265 if (v == x)
5266 return false;
5267 llvm::SmallVector<int64_t> xArr(x.second.size());
5268 getAsIntegers(x.second, xArr);
5269 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5270 xArr.size() >= vArr.size();
5271 }));
5272 }
5273
5274 // Collect the indices, as we need the base pointer etc. from the MapData
5275 // structure which is primarily accessible via index at the moment.
5276 for (auto v : memberByIndex)
5277 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5278 overlapMapDataIdxs.push_back(v.first);
5279}
5280
5281// The intent is to verify if the mapped data being passed is a
5282// pointer -> pointee that requires special handling in certain cases,
5283// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
5284//
5285// There may be a better way to verify this, but unfortunately with
5286// opaque pointers we lose the ability to easily check if something is
5287// a pointer whilst maintaining access to the underlying type.
5288static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
5289 // If we have a varPtrPtr field assigned then the underlying type is a pointer
5290 if (mapOp.getVarPtrPtr())
5291 return true;
5292
5293 // If the map data is declare target with a link clause, then it's represented
5294 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
5295 // no relation to pointers.
5296 if (isDeclareTargetLink(mapOp.getVarPtr()))
5297 return true;
5298
5299 return false;
5300}
5301
5302// This creates two insertions into the MapInfosTy data structure for the
5303// "parent" of a set of members, (usually a container e.g.
5304// class/structure/derived type) when subsequent members have also been
5305// explicitly mapped on the same map clause. Certain types, such as Fortran
5306// descriptors are mapped like this as well, however, the members are
5307// implicit as far as a user is concerned, but we must explicitly map them
5308// internally.
5309//
5310// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map type
5312// with it) to indicate that a member is part of this parent and should be
5313// treated by the runtime as such. Important to achieve the correct mapping.
5314//
5315// This function borrows a lot from Clang's emitCombinedEntry function
5316// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  auto *parentMapper = mapData.Mappers[mapDataIndex];

  // Map the first segment of the parent. If a user-defined mapper is attached,
  // include the parent's to/from-style bits (and common modifiers) in this
  // base entry so the mapper receives correct copy semantics via its 'type'
  // parameter. Also keep TARGET_PARAM when required for kernel arguments
  // (only `target` itself takes kernel arguments; declare target data is
  // accessed through its reference pointer instead).
  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  if (parentMapper) {
    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
    // Preserve relevant map-type bits from the parent clause. These include
    // the copy direction (TO/FROM), as well as commonly used modifiers that
    // should be visible to the mapper for correct behaviour.
    mapFlags parentFlags = mapData.Types[mapDataIndex];
    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
    baseFlag |= (parentFlags & preserve);
  }

  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  // Only attach the mapper to the base entry when we are mapping the whole
  // parent. Combined/segment entries must not carry a mapper; otherwise the
  // mapper can be invoked with a partial size, which is undefined behaviour.
  combinedInfo.Mappers.emplace_back(
      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Whole object mapped: the segment spans [parent, parent + 1).
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: the segment spans from the lowest-addressed mapped member
    // to one past the highest-addressed mapped member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  // MEMBER_OF encodes the position of the parent entry just emitted; member
  // entries use it to reference their parent.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // Single whole-object entry with the true map type; no overlap
      // splitting is performed for target update or `close`-modified maps.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
    } else {
      llvm::SmallVector<size_t> overlapIdxs;
      // Find all of the members that "overlap", i.e. occlude other members that
      // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
      // [0,2], but not [1,0].
      getOverlappedMembers(overlapIdxs, parentClause);
      // We need to make sure the overlapped members are sorted in order of
      // lowest address to highest address.
      sortMapIndices(overlapIdxs, parentClause);

      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      // Emit one segment per gap between overlapping members, i.e. map the
      // parent in pieces, skipping over each overlapped member's storage.
      // TODO: We may want to skip arrays/array sections in this as Clang does.
      // It appears to be an optimisation rather than a necessity though,
      // but this requires further investigation. However, we would have to make
      // sure to not exclude maps with bounds that ARE pointers, as these are
      // processed as separate components, i.e. pointer + data.
      for (auto v : overlapIdxs) {
        auto mapDataOverlapIdx = getMapDataMemberIdx(
            mapData,
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(nullptr);
        combinedInfo.Pointers.emplace_back(lowAddr);
        // Segment size: from the current segment start up to the overlapped
        // member's storage.
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
                                  lowAddr),
            builder.getInt64Ty(), /*isSigned=*/true));
        // Next segment starts one past the overlapped member (a pointer-sized
        // slot when the member is itself a pointer map).
        lowAddr = builder.CreateConstGEP1_32(
            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      // Final segment: from the last overlapped member's end to the end of
      // the parent object.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), /*isSigned=*/true));
    }
  }
  return memberOfFlag;
}
5496
// Emits into `combinedInfo` one map entry per member clause attached to the
// parent map at `mapDataIndex`, tagging each entry with the parent's
// MEMBER_OF flag so the runtime binds members to their parent. Pointer-typed
// members additionally receive a preceding pointer map and, in most cases,
// the PTR_AND_OBJ flag.
// This function is intended to add explicit mappings of members
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  for (auto mappedMembers : parentClause.getMembers()) {
    auto memberClause =
        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);

    assert(memberDataIdx >= 0 && "could not find mapped member of structure");

    // If we're currently mapping a pointer to a block of data, we must
    // initially map the pointer, and then attach/bind the data with a
    // subsequent map to the pointer. This segment of code generates the
    // pointer mapping, which can in certain cases be optimised out as Clang
    // currently does in its lowering. However, for the moment we do not do so,
    // in part as we currently have substantially less information on the data
    // being mapped at this stage.
    if (checkIfPointerMap(memberClause)) {
      auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
      mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Names.emplace_back(
          LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
      // The pointer map uses the parent as base and the member's storage as
      // the pointer; its size is a single host pointer.
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
      combinedInfo.Sizes.emplace_back(builder.getInt64(
          moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
    }

    // Same MemberOfFlag to indicate its link with parent and other members
    // of.
    auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
                                                ? parentClause.getVarPtr()
                                                : parentClause.getVarPtrPtr());
    // NOTE(review): for "declare target to" parents, PTR_AND_OBJ is skipped
    // on target update/data directives — presumably the declare-target
    // machinery handles the attachment there; confirm against the runtime.
    if (checkIfPointerMap(memberClause) &&
        (!isDeclTargetTo ||
         (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
          targetDirective != TargetDirectiveEnumTy::TargetData))) {
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
    }

    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        mapData.DevicePointers[memberDataIdx]);
    combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    // Pointer members use their own storage as the base pointer; non-pointer
    // members are based off the parent mapping.
    uint64_t basePointerIndex =
        checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[basePointerIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);

    // Null pointer members are given size zero so no transfer is attempted.
    llvm::Value *size = mapData.Sizes[memberDataIdx];
    if (checkIfPointerMap(memberClause)) {
      size = builder.CreateSelect(
          builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
          builder.getInt64(0), size);
    }

    combinedInfo.Sizes.emplace_back(size);
  }
}
5580
5581static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5582 MapInfosTy &combinedInfo,
5583 TargetDirectiveEnumTy targetDirective,
5584 int mapDataParentIdx = -1) {
5585 // Declare Target Mappings are excluded from being marked as
5586 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5587 // marked with OMP_MAP_PTR_AND_OBJ instead.
5588 auto mapFlag = mapData.Types[mapDataIdx];
5589 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5590
5591 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5592 if (isPtrTy)
5593 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5594
5595 if (targetDirective == TargetDirectiveEnumTy::Target &&
5596 !mapData.IsDeclareTarget[mapDataIdx])
5597 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5598
5599 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5600 !isPtrTy)
5601 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5602
5603 // if we're provided a mapDataParentIdx, then the data being mapped is
5604 // part of a larger object (in a parent <-> member mapping) and in this
5605 // case our BasePointer should be the parent.
5606 if (mapDataParentIdx >= 0)
5607 combinedInfo.BasePointers.emplace_back(
5608 mapData.BasePointers[mapDataParentIdx]);
5609 else
5610 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5611
5612 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5613 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5614 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5615 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5616 combinedInfo.Types.emplace_back(mapFlag);
5617 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5618}
5619
                                    llvm::IRBuilderBase &builder,
                                    llvm::OpenMPIRBuilder &ompBuilder,
                                    DataLayout &dl, MapInfosTy &combinedInfo,
                                    MapInfoData &mapData, uint64_t mapDataIndex,
                                    TargetDirectiveEnumTy targetDirective) {
  // Lowers a map clause that has member clauses attached: either forwards a
  // lone partial-map member to processIndividualMap, or maps the parent
  // first and then binds each member to it via its MEMBER_OF flag.
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  // If we have a partial map (no parent referenced in the map clauses of the
  // directive, only members) and only a single member, we do not need to bind
  // the map of the member to the parent, we can pass the member separately.
  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
    auto memberClause = llvm::cast<omp::MapInfoOp>(
        parentClause.getMembers()[0].getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
    // Note: Clang treats arrays with explicit bounds that fall into this
    // category as a parent with map case, however, it seems this isn't a
    // requirement, and processing them as an individual map is fine. So,
    // we will handle them as individual maps for the moment, as it's
    // difficult for us to check this as we always require bounds to be
    // specified currently and it's also marginally more optimal (single
    // map rather than two). The difference may come from the fact that
    // Clang maps array without bounds as pointers (which we do not
    // currently do), whereas we treat them as arrays in all cases
    // currently.
    processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
                         mapDataIndex);
    return;
  }

  // General case: map the parent first (yielding the MEMBER_OF flag shared
  // by all its members), then emit the member mappings bound to it.
  llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
      mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                           combinedInfo, mapData, mapDataIndex,
                           targetDirective);
  processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
                              combinedInfo, mapData, mapDataIndex,
                              memberOfParentFlag, targetDirective);
}
5662
5663// This is a variation on Clang's GenerateOpenMPCapturedVars, which
5664// generates different operation (e.g. load/store) combinations for
5665// arguments to the kernel, based on map capture kinds which are then
5666// utilised in the combinedInfo in place of the original Map value.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  // Rewrites MapInfoData's Pointers/BasePointers in place according to each
  // clause's capture kind (by-ref vs by-copy), emitting the loads/stores
  // needed for kernel argument passing. Only non-declare-target clauses are
  // altered.
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defines, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
            mapOp.getBounds());
        // Pointer maps are dereferenced first so the bounds offset applies
        // to the pointee, not the pointer itself.
        if (isPtrTy)
          newV = builder.CreateLoad(builder.getPtrTy(), newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
                                           "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(type, mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        // Non-pointer by-copy values are spilled to a temporary alloca
        // (placed in the function's alloca block) and re-loaded, mirroring
        // Clang's ".casted" argument handling.
        if (!isPtrTy) {
          auto curInsert = builder.saveIP();
          llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
          builder.SetCurrentDebugLocation(DbgLoc);
          builder.restoreIP(curInsert);

          builder.CreateStore(newV, memTempAlloc);
          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        mapData.MapClause[i]->emitOpError("Unhandled capture kind");
        break;
      }
    }
  }
}
5733
5734// Generate all map related information and fill the combinedInfo.
5735static void genMapInfos(llvm::IRBuilderBase &builder,
5736 LLVM::ModuleTranslation &moduleTranslation,
5737 DataLayout &dl, MapInfosTy &combinedInfo,
5738 MapInfoData &mapData,
5739 TargetDirectiveEnumTy targetDirective) {
5740 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5741 "function only supported for host device codegen");
5742 // We wish to modify some of the methods in which arguments are
5743 // passed based on their capture type by the target region, this can
5744 // involve generating new loads and stores, which changes the
5745 // MLIR value to LLVM value mapping, however, we only wish to do this
5746 // locally for the current function/target and also avoid altering
5747 // ModuleTranslation, so we remap the base pointer or pointer stored
5748 // in the map infos corresponding MapInfoData, which is later accessed
5749 // by genMapInfos and createTarget to help generate the kernel and
5750 // kernel arg structure. It primarily becomes relevant in cases like
5751 // bycopy, or byref range'd arrays. In the default case, we simply
5752 // pass thee pointer byref as both basePointer and pointer.
5753 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5754
5755 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5756
5757 // We operate under the assumption that all vectors that are
5758 // required in MapInfoData are of equal lengths (either filled with
5759 // default constructed data or appropiate information) so we can
5760 // utilise the size from any component of MapInfoData, if we can't
5761 // something is missing from the initial MapInfoData construction.
5762 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5763 // NOTE/TODO: We currently do not support arbitrary depth record
5764 // type mapping.
5765 if (mapData.IsAMember[i])
5766 continue;
5767
5768 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5769 if (!mapInfoOp.getMembers().empty()) {
5770 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5771 combinedInfo, mapData, i, targetDirective);
5772 continue;
5773 }
5774
5775 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5776 }
5777}
5778
5779static llvm::Expected<llvm::Function *>
5780emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
5781 LLVM::ModuleTranslation &moduleTranslation,
5782 llvm::StringRef mapperFuncName,
5783 TargetDirectiveEnumTy targetDirective);
5784
5785static llvm::Expected<llvm::Function *>
5786getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5787 LLVM::ModuleTranslation &moduleTranslation,
5788 TargetDirectiveEnumTy targetDirective) {
5789 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5790 "function only supported for host device codegen");
5791 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5792 std::string mapperFuncName =
5793 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5794 {"omp_mapper", declMapperOp.getSymName()});
5795
5796 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5797 return lookupFunc;
5798
5799 // Recursive types can cause re-entrant mapper emission. The mapper function
5800 // is created by OpenMPIRBuilder before the callbacks run, so it may already
5801 // exist in the LLVM module even though it is not yet registered in the
5802 // ModuleTranslation mapping table. Reuse and register it to break the
5803 // recursion.
5804 if (llvm::Function *existingFunc =
5805 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
5806 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
5807 return existingFunc;
5808 }
5809
5810 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5811 mapperFuncName, targetDirective);
5812}
5813
// Emits the LLVM function named `mapperFuncName` implementing the
// user-defined mapper `op` (an omp.declare_mapper) via
// OpenMPIRBuilder::emitUserDefinedMapper, and registers it with the
// ModuleTranslation. Returns the emitted function or a reported error.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    // Bind the mapper's symbolic variable to the PHI supplied by the
    // builder, then translate the declare_mapper body region in place.
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Nested user-defined mappers are emitted (or reused) on demand.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Recursive emission may have registered the function already (see
  // getOrCreateUserDefinedMapperFunc); in that case only check consistency
  // rather than re-registering.
  if ([[maybe_unused]] llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
5875
// Lowers the host side of omp.target_data / omp.target_enter_data /
// omp.target_exit_data / omp.target_update: gathers map and use_device
// clauses, selects the matching offload runtime entry point, and emits the
// data-mapping calls through OpenMPIRBuilder::createTargetData.
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
  TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/true,
      /*SeparateBeginEndCalls=*/true);
  assert(!ompBuilder->Config.isTargetDevice() &&
         "target data/enter/exit/update are host ops");
  bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();

  // The device clause operand is converted to a signed 64-bit device id.
  auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
    llvm::Value *v = moduleTranslation.lookupValue(dev);
    return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
  };

  // Dispatch on the concrete directive to collect its clauses and, for the
  // standalone directives, the runtime function to call.
  LogicalResult result =
          .Case([&](omp::TargetDataOp dataOp) {
            if (failed(checkImplementationStatus(*dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = dataOp.getDevice())
              deviceID = getDeviceID(devId);

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = enterDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = exitDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(checkImplementationStatus(*updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(ifVar);

            if (mlir::Value devId = updateDataOp.getDevice())
              deviceID = getDeviceID(devId);

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .DefaultUnreachable("unexpected operation");

  if (failed(result))
    return failure();
  // Pretend we have IF(false) if we're not doing offload.
  if (!isOffloadEntry)
    ifCond = builder.getFalse();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
                                builder, useDevicePtrVars, useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
                targetDirective);
    return combinedInfo;
  };

  // Define a lambda to apply mappings between use_device_addr and
  // use_device_ptr base pointers, and their associated block arguments.
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(blockArgs, useDeviceVars)) {

          // A map clause's "base" is the pointer-to-pointer when present,
          // else the variable pointer itself.
          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   mapInfoData.MapClause, mapInfoData.DevicePointers,
                   mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(arg, devPtrInfoMap);
              break;
            }
          }
        }
      };

  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We must always restoreIP regardless of doing anything the caller
    // does not restore it, leading to incorrect (no) branch generation.
    builder.restoreIP(codeGenIP);
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
    Region &region = cast<omp::TargetDataOp>(op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available
      if (!info.DevicePtrInfoMap.empty()) {
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           builder.getPtrTy(),
                           info.DevicePtrInfoMap[basePointer].second);
                     });
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      if (info.DevicePtrInfoMap.empty()) {
        // For host device we still need to do the mapping for codegen,
        // otherwise it may try to lookup a missing value.
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData);
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData);
      }
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then region has already been generated
      if (info.DevicePtrInfoMap.empty()) {
        if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
                                           moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  // User-defined mappers referenced by any map entry are emitted on demand.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  // target_data takes a body; the standalone directives only need the
  // runtime call selected above.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(op))
      return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                          deviceID, ifCond, info, genMapInfoCB,
                                          customMapperCB,
                                          /*MapperFunc=*/nullptr, bodyGenCB,
                                          /*DeviceAddrCB=*/nullptr);
    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
                                        deviceID, ifCond, info, genMapInfoCB,
                                        customMapperCB, &RTLFn);
  }();

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
6116
// Lowers omp.distribute: privatizes its variables, translates the loop
// region, applies a distribute workshare loop (unless a nested omp.wsloop
// already did), and handles teams reductions contained in the distribute.
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  /// Process teams op reduction in distribute if the reduction is contained in
  /// the distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  if (doDistributeReduction) {
    isByRef = getIsByRef(teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(teamsOp, reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
            .getReductionBlockArgs();

            teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(codeGenIP);
    PrivateVarsInfo privVarsInfo(distributeOp);

        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
    if (handleError(afterAllocas, opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                    opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
            privVarsInfo.llvmVars, privVarsInfo.privatizers,
            distributeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
        convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

    // Skip applying a workshare loop below when translating 'distribute
    // parallel do' (it's been already handled by this point while translating
    // the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
                                      : omp::ClauseScheduleKind::Static;
      // dist_schedule clauses are ordered - otherwise this should be false
      bool isOrdered = hasDistSchedule;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      llvm::Value *chunk = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      // NOTE(review): `chunk` is passed both as the schedule chunk and as
      // the trailing dist_schedule chunk argument — confirm this matches
      // applyWorkshareLoop's parameter list.
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
              convertToScheduleKind(schedule), chunk, isSimd,
              scheduleMod == omp::ScheduleModifier::monotonic,
              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
              workshareLoopType, false, hasDistSchedule, chunk);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
                                  privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
        teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
6252
6253/// Lowers the FlagsAttr which is applied to the module on the device
6254/// pass when offloading, this attribute contains OpenMP RTL globals that can
/// be passed as flags to the frontend; otherwise they are set to defaults.
6256static LogicalResult
6257convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
6258 LLVM::ModuleTranslation &moduleTranslation) {
6259 if (!cast<mlir::ModuleOp>(op))
6260 return failure();
6261
6262 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6263
6264 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
6265 attribute.getOpenmpDeviceVersion());
6266
6267 if (attribute.getNoGpuLib())
6268 return success();
6269
6270 ompBuilder->createGlobalFlag(
6271 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
6272 "__omp_rtl_debug_kind");
6273 ompBuilder->createGlobalFlag(
6274 attribute
6275 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
6276 ,
6277 "__omp_rtl_assume_teams_oversubscription");
6278 ompBuilder->createGlobalFlag(
6279 attribute
6280 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
6281 ,
6282 "__omp_rtl_assume_threads_oversubscription");
6283 ompBuilder->createGlobalFlag(
6284 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
6285 "__omp_rtl_assume_no_thread_state");
6286 ompBuilder->createGlobalFlag(
6287 attribute
6288 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
6289 ,
6290 "__omp_rtl_assume_no_nested_parallelism");
6291 return success();
6292}
6293
6294static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
6295 omp::TargetOp targetOp,
6296 llvm::StringRef parentName = "") {
6297 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
6298
6299 assert(fileLoc && "No file found from location");
6300 StringRef fileName = fileLoc.getFilename().getValue();
6301
6302 llvm::sys::fs::UniqueID id;
6303 uint64_t line = fileLoc.getLine();
6304 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
6305 size_t fileHash = llvm::hash_value(fileName.str());
6306 size_t deviceId = 0xdeadf17e;
6307 targetInfo =
6308 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
6309 } else {
6310 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
6311 id.getFile(), line);
6312 }
6313}
6314
6315static void
6316handleDeclareTargetMapVar(MapInfoData &mapData,
6317 LLVM::ModuleTranslation &moduleTranslation,
6318 llvm::IRBuilderBase &builder, llvm::Function *func) {
6319 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6320 "function only supported for target device codegen");
6321 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6322 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6323 // In the case of declare target mapped variables, the basePointer is
6324 // the reference pointer generated by the convertDeclareTargetAttr
6325 // method. Whereas the kernelValue is the original variable, so for
6326 // the device we must replace all uses of this original global variable
6327 // (stored in kernelValue) with the reference pointer (stored in
6328 // basePointer for declare target mapped variables), as for device the
6329 // data is mapped into this reference pointer and should be loaded
6330 // from it, the original variable is discarded. On host both exist and
6331 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
6332 // function to link the two variables in the runtime and then both the
6333 // reference pointer and the pointer are assigned in the kernel argument
6334 // structure for the host.
6335 if (mapData.IsDeclareTarget[i]) {
6336 // If the original map value is a constant, then we have to make sure all
6337 // of it's uses within the current kernel/function that we are going to
6338 // rewrite are converted to instructions, as we will be altering the old
6339 // use (OriginalValue) from a constant to an instruction, which will be
6340 // illegal and ICE the compiler if the user is a constant expression of
6341 // some kind e.g. a constant GEP.
6342 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6343 convertUsersOfConstantsToInstructions(constant, func, false);
6344
6345 // The users iterator will get invalidated if we modify an element,
6346 // so we populate this vector of uses to alter each user on an
6347 // individual basis to emit its own load (rather than one load for
6348 // all).
6350 for (llvm::User *user : mapData.OriginalValue[i]->users())
6351 userVec.push_back(user);
6352
6353 for (llvm::User *user : userVec) {
6354 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
6355 if (insn->getFunction() == func) {
6356 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6357 llvm::Value *substitute = mapData.BasePointers[i];
6358 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
6359 : mapOp.getVarPtr())) {
6360 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6361 substitute = builder.CreateLoad(
6362 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6363 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6364 }
6365 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6366 }
6367 }
6368 }
6369 }
6370 }
6371}
6372
6373// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
6375// arguments inside of the device kernel for use by
6376// the kernel. This enables different semantics such as
6377// the creation of temporary copies of data allowing
6378// semantics like read-only/no host write back kernel
6379// arguments.
6380//
6381// This currently implements a very light version of Clang's
6382// EmitParmDecl's handling of direct argument handling as well
6383// as a portion of the argument access generation based on
6384// capture types found at the end of emitOutlinedFunctionPrologue
6385// in Clang. The indirect path handling of EmitParmDecl's may be
6386// required for future work, but a direct 1-to-1 copy doesn't seem
6387// possible as the logic is rather scattered throughout Clang's
6388// lowering and perhaps we wish to deviate slightly.
6389//
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData container's
// OriginalValue vector.
6394// \param arg - This is the generated kernel function argument that
6395// corresponds to the passed in input argument. We generated different
6396// accesses of this Argument, based on capture type and other Input
6397// related information.
6398// \param input - This is the host side value that will be passed to
6399// the kernel i.e. the kernel input, we rewrite all uses of this within
6400// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function, providing
// appropriate "glue" instructions in between.
6405// \param retVal - This is the value that all uses of input inside of the
6406// kernel will be re-written to, the goal of this function is to generate
6407// an appropriate location for the kernel argument to be accessed from,
6408// e.g. ByRef will result in a temporary allocation location and then
6409// a store of the kernel argument into this allocated memory which
6410// will then be loaded from, ByCopy will use the allocated memory
6411// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // Local storage for the kernel argument is emitted at the kernel's alloca
  // insertion point, not at the current codegen position.
  builder.restoreIP(allocaIP);

  // Default to ByRef when no matching MapInfoData entry is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // When allocas live in a non-default address space, cast pointer-typed
  // slots back to the program address space so later uses see a generic
  // pointer.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  // Spill the incoming kernel argument into the slot just created.
  builder.CreateStore(&arg, v);

  // Switch back to the kernel body before emitting the access itself.
  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // By-copy captures use the stored argument slot directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // By-ref captures re-load the pointer value stored above.
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^                     ^
    //                 |                     |
    //        alignment of %input address   |
    //                                      |
    //                           alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
6499
6500/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6501/// operation and populate output variables with their corresponding host value
6502/// (i.e. operand evaluated outside of the target region), based on their uses
6503/// inside of the target region.
6504///
6505/// Loop bounds and steps are only optionally populated, if output vectors are
6506/// provided.
6507static void
6508extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
6509 Value &numTeamsLower, Value &numTeamsUpper,
6510 Value &threadLimit,
6511 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
6512 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
6513 llvm::SmallVectorImpl<Value> *steps = nullptr) {
6514 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6515 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6516 blockArgIface.getHostEvalBlockArgs())) {
6517 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6518
6519 for (Operation *user : blockArg.getUsers()) {
6521 .Case([&](omp::TeamsOp teamsOp) {
6522 if (teamsOp.getNumTeamsLower() == blockArg)
6523 numTeamsLower = hostEvalVar;
6524 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6525 blockArg))
6526 numTeamsUpper = hostEvalVar;
6527 else if (!teamsOp.getThreadLimitVars().empty() &&
6528 teamsOp.getThreadLimit(0) == blockArg)
6529 threadLimit = hostEvalVar;
6530 else
6531 llvm_unreachable("unsupported host_eval use");
6532 })
6533 .Case([&](omp::ParallelOp parallelOp) {
6534 if (!parallelOp.getNumThreadsVars().empty() &&
6535 parallelOp.getNumThreads(0) == blockArg)
6536 numThreads = hostEvalVar;
6537 else
6538 llvm_unreachable("unsupported host_eval use");
6539 })
6540 .Case([&](omp::LoopNestOp loopOp) {
6541 auto processBounds =
6542 [&](OperandRange opBounds,
6543 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
6544 bool found = false;
6545 for (auto [i, lb] : llvm::enumerate(opBounds)) {
6546 if (lb == blockArg) {
6547 found = true;
6548 if (outBounds)
6549 (*outBounds)[i] = hostEvalVar;
6550 }
6551 }
6552 return found;
6553 };
6554 bool found =
6555 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6556 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6557 found;
6558 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6559 (void)found;
6560 assert(found && "unsupported host_eval use");
6561 })
6562 .DefaultUnreachable("unsupported host_eval use");
6563 }
6564 }
6565}
6566
6567/// If \p op is of the given type parameter, return it casted to that type.
6568/// Otherwise, if its immediate parent operation (or some other higher-level
6569/// parent, if \p immediateParent is false) is of that type, return that parent
6570/// casted to the given type.
6571///
/// If \p op is \c null or neither it nor its parent(s) are of the specified
6573/// type, return a \c null operation.
6574template <typename OpTy>
6575static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6576 if (!op)
6577 return OpTy();
6578
6579 if (OpTy casted = dyn_cast<OpTy>(op))
6580 return casted;
6581
6582 if (immediateParent)
6583 return dyn_cast_if_present<OpTy>(op->getParentOp());
6584
6585 return op->getParentOfType<OpTy>();
6586}
6587
6588/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6589/// it is of an integer type, return its value.
6590static std::optional<int64_t> extractConstInteger(Value value) {
6591 if (!value)
6592 return std::nullopt;
6593
6594 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6595 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6596 return constAttr.getInt();
6597
6598 return std::nullopt;
6599}
6600
6601static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6602 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6603 uint64_t sizeInBytes = sizeInBits / 8;
6604 return sizeInBytes;
6605}
6606
6607template <typename OpTy>
6608static uint64_t getReductionDataSize(OpTy &op) {
6609 if (op.getNumReductionVars() > 0) {
6611 collectReductionDecls(op, reductions);
6612
6614 members.reserve(reductions.size());
6615 for (omp::DeclareReductionOp &red : reductions)
6616 members.push_back(red.getType());
6617 Operation *opp = op.getOperation();
6618 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6619 opp->getContext(), members, /*isPacked=*/false);
6620 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
6621 return getTypeByteSize(structType, dl);
6622 }
6623 return 0;
6624}
6625
6626/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
6627/// values as stated by the corresponding clauses, if constant.
6628///
6629/// These default values must be set before the creation of the outlined LLVM
6630/// function for the target region, so that they can be used to initialize the
6631/// corresponding global `ConfigurationEnvironmentTy` structure.
6632static void
6633initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
6634 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6635 bool isTargetDevice, bool isGPU) {
6636 // TODO: Handle constant 'if' clauses.
6637
6638 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6639 if (!isTargetDevice) {
6640 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6641 threadLimit);
6642 } else {
6643 // In the target device, values for these clauses are not passed as
6644 // host_eval, but instead evaluated prior to entry to the region. This
6645 // ensures values are mapped and available inside of the target region.
6646 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6647 numTeamsLower = teamsOp.getNumTeamsLower();
6648 // Handle num_teams upper bounds (only first value for now)
6649 if (!teamsOp.getNumTeamsUpperVars().empty())
6650 numTeamsUpper = teamsOp.getNumTeams(0);
6651 if (!teamsOp.getThreadLimitVars().empty())
6652 threadLimit = teamsOp.getThreadLimit(0);
6653 }
6654
6655 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
6656 if (!parallelOp.getNumThreadsVars().empty())
6657 numThreads = parallelOp.getNumThreads(0);
6658 }
6659 }
6660
6661 // Handle clauses impacting the number of teams.
6662
6663 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6664 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6665 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
6666 // match clang and set min and max to the same value.
6667 if (numTeamsUpper) {
6668 if (auto val = extractConstInteger(numTeamsUpper))
6669 minTeamsVal = maxTeamsVal = *val;
6670 } else {
6671 minTeamsVal = maxTeamsVal = 0;
6672 }
6673 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
6674 /*immediateParent=*/true) ||
6676 /*immediateParent=*/true)) {
6677 minTeamsVal = maxTeamsVal = 1;
6678 } else {
6679 minTeamsVal = maxTeamsVal = -1;
6680 }
6681
6682 // Handle clauses impacting the number of threads.
6683
6684 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
6685 if (!clauseValue)
6686 return;
6687
6688 if (auto val = extractConstInteger(clauseValue))
6689 result = *val;
6690
6691 // Found an applicable clause, so it's not undefined. Mark as unknown
6692 // because it's not constant.
6693 if (result < 0)
6694 result = 0;
6695 };
6696
6697 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
6698 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6699 if (!targetOp.getThreadLimitVars().empty())
6700 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6701 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6702
6703 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
6704 int32_t maxThreadsVal = -1;
6706 setMaxValueFromClause(numThreads, maxThreadsVal);
6707 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
6708 /*immediateParent=*/true))
6709 maxThreadsVal = 1;
6710
6711 // For max values, < 0 means unset, == 0 means set but unknown. Select the
6712 // minimum value between 'max_threads' and 'thread_limit' clauses that were
6713 // set.
6714 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6715 if (combinedMaxThreadsVal < 0 ||
6716 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6717 combinedMaxThreadsVal = teamsThreadLimitVal;
6718
6719 if (combinedMaxThreadsVal < 0 ||
6720 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6721 combinedMaxThreadsVal = maxThreadsVal;
6722
6723 int32_t reductionDataSize = 0;
6724 if (isGPU && capturedOp) {
6725 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
6726 reductionDataSize = getReductionDataSize(teamsOp);
6727 }
6728
6729 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
6730 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6731 assert(
6732 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6733 omp::TargetRegionFlags::spmd) &&
6734 "invalid kernel flags");
6735 attrs.ExecFlags =
6736 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6737 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6738 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6739 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6740 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6741 if (omp::bitEnumContainsAll(kernelFlags,
6742 omp::TargetRegionFlags::spmd |
6743 omp::TargetRegionFlags::no_loop) &&
6744 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6745 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6746
6747 attrs.MinTeams = minTeamsVal;
6748 attrs.MaxTeams.front() = maxTeamsVal;
6749 attrs.MinThreads = 1;
6750 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6751 attrs.ReductionDataSize = reductionDataSize;
6752 // TODO: Allow modified buffer length similar to
6753 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
6754 if (attrs.ReductionDataSize != 0)
6755 attrs.ReductionBufferLength = 1024;
6756}
6757
6758/// Gather LLVM runtime values for all clauses evaluated in the host that are
6759/// passed to the kernel invocation.
6760///
6761/// This function must be called only when compiling for the host. Also, it will
6762/// only provide correct results if it's called after the body of \c targetOp
6763/// has been fully generated.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  // Loop bounds/steps are only needed when a trip count must be computed.
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // If any bound hasn't been translated, the overall trip count cannot
      // be computed; leave it unset.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop seeds the product; subsequent loops multiply into it.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Device ID defaults to 'undefined' unless a 'device' clause is present,
  // in which case its value is normalized to i64 for the runtime call.
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
6845
6846static LogicalResult
6847convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
6848 LLVM::ModuleTranslation &moduleTranslation) {
6849 auto targetOp = cast<omp::TargetOp>(opInst);
6850
6851 // The current debug location already has the DISubprogram for the outlined
6852 // function that will be created for the target op. We save it here so that
6853 // we can set it on the outlined function.
6854 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6855 if (failed(checkImplementationStatus(opInst)))
6856 return failure();
6857
6858 // During the handling of target op, we will generate instructions in the
6859 // parent function like call to the oulined function or branch to a new
6860 // BasicBlock. We set the debug location here to parent function so that those
6861 // get the correct debug locations. For outlined functions, the normal MLIR op
6862 // conversion will automatically pick the correct location.
6863 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6864 assert(parentBB && "No insert block is set for the builder");
6865 llvm::Function *parentLLVMFn = parentBB->getParent();
6866 assert(parentLLVMFn && "Parent Function must be valid");
6867 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6868 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6869 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6870 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6871
6872 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6873 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6874 bool isGPU = ompBuilder->Config.isGPU();
6875
6876 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
6877 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6878 auto &targetRegion = targetOp.getRegion();
6879 // Holds the private vars that have been mapped along with the block
6880 // argument that corresponds to the MapInfoOp corresponding to the private
6881 // var in question. So, for instance:
6882 //
6883 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
6884 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
6885 //
6886 // Then, %10 has been created so that the descriptor can be used by the
6887 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
6888 // %arg0} in the mappedPrivateVars map.
6889 llvm::DenseMap<Value, Value> mappedPrivateVars;
6890 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
6891 SmallVector<Value> mapVars = targetOp.getMapVars();
6892 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
6893 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
6894 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
6895 llvm::Function *llvmOutlinedFn = nullptr;
6896 TargetDirectiveEnumTy targetDirective =
6897 getTargetDirectiveEnumTyFromOp(&opInst);
6898
6899 // TODO: It can also be false if a compile-time constant `false` IF clause is
6900 // specified.
6901 bool isOffloadEntry =
6902 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6903
6904 // For some private variables, the MapsForPrivatizedVariablesPass
6905 // creates MapInfoOp instances. Go through the private variables and
6906 // the mapped variables so that during codegeneration we are able
6907 // to quickly look up the corresponding map variable, if any for each
6908 // private variable.
6909 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6910 OperandRange privateVars = targetOp.getPrivateVars();
6911 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6912 std::optional<DenseI64ArrayAttr> privateMapIndices =
6913 targetOp.getPrivateMapsAttr();
6914
6915 for (auto [privVarIdx, privVarSymPair] :
6916 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6917 auto privVar = std::get<0>(privVarSymPair);
6918 auto privSym = std::get<1>(privVarSymPair);
6919
6920 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6921 omp::PrivateClauseOp privatizer =
6922 findPrivatizer(targetOp, privatizerName);
6923
6924 if (!privatizer.needsMap())
6925 continue;
6926
6927 mlir::Value mappedValue =
6928 targetOp.getMappedValueForPrivateVar(privVarIdx);
6929 assert(mappedValue && "Expected to find mapped value for a privatized "
6930 "variable that needs mapping");
6931
6932 // The MapInfoOp defining the map var isn't really needed later.
6933 // So, we don't store it in any datastructure. Instead, we just
6934 // do some sanity checks on it right now.
6935 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
6936 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
6937
6938 // Check #1: Check that the type of the private variable matches
6939 // the type of the variable being mapped.
6940 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6941 assert(
6942 varType == privVar.getType() &&
6943 "Type of private var doesn't match the type of the mapped value");
6944
6945 // Ok, only 1 sanity check for now.
6946 // Record the block argument corresponding to this mapvar.
6947 mappedPrivateVars.insert(
6948 {privVar,
6949 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6950 (*privateMapIndices)[privVarIdx])});
6951 }
6952 }
6953
6954 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6955 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6956 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6957 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6958 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6959 // Forward target-cpu and target-features function attributes from the
6960 // original function to the new outlined function.
6961 llvm::Function *llvmParentFn =
6962 moduleTranslation.lookupFunction(parentFn.getName());
6963 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6964 assert(llvmParentFn && llvmOutlinedFn &&
6965 "Both parent and outlined functions must exist at this point");
6966
6967 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6968 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6969
6970 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
6971 attr.isStringAttribute())
6972 llvmOutlinedFn->addFnAttr(attr);
6973
6974 if (auto attr = llvmParentFn->getFnAttribute("target-features");
6975 attr.isStringAttribute())
6976 llvmOutlinedFn->addFnAttr(attr);
6977
6978 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6979 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6980 llvm::Value *mapOpValue =
6981 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6982 moduleTranslation.mapValue(arg, mapOpValue);
6983 }
6984 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6985 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6986 llvm::Value *mapOpValue =
6987 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6988 moduleTranslation.mapValue(arg, mapOpValue);
6989 }
6990
6991 // Do privatization after moduleTranslation has already recorded
6992 // mapped values.
6993 PrivateVarsInfo privateVarsInfo(targetOp);
6994
6996 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
6997 allocaIP, &mappedPrivateVars);
6998
6999 if (failed(handleError(afterAllocas, *targetOp)))
7000 return llvm::make_error<PreviouslyReportedError>();
7001
7002 builder.restoreIP(codeGenIP);
7003 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
7004 &mappedPrivateVars),
7005 *targetOp)
7006 .failed())
7007 return llvm::make_error<PreviouslyReportedError>();
7008
7009 if (failed(copyFirstPrivateVars(
7010 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
7011 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
7012 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7013 return llvm::make_error<PreviouslyReportedError>();
7014
7015 SmallVector<Region *> privateCleanupRegions;
7016 llvm::transform(privateVarsInfo.privatizers,
7017 std::back_inserter(privateCleanupRegions),
7018 [](omp::PrivateClauseOp privatizer) {
7019 return &privatizer.getDeallocRegion();
7020 });
7021
7023 targetRegion, "omp.target", builder, moduleTranslation);
7024
7025 if (!exitBlock)
7026 return exitBlock.takeError();
7027
7028 builder.SetInsertPoint(*exitBlock);
7029 if (!privateCleanupRegions.empty()) {
7030 if (failed(inlineOmpRegionCleanup(
7031 privateCleanupRegions, privateVarsInfo.llvmVars,
7032 moduleTranslation, builder, "omp.targetop.private.cleanup",
7033 /*shouldLoadCleanupRegionArg=*/false))) {
7034 return llvm::createStringError(
7035 "failed to inline `dealloc` region of `omp.private` "
7036 "op in the target region");
7037 }
7038 return builder.saveIP();
7039 }
7040
7041 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
7042 };
7043
7044 StringRef parentName = parentFn.getName();
7045
7046 llvm::TargetRegionEntryInfo entryInfo;
7047
7048 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
7049
7050 MapInfoData mapData;
7051 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
7052 builder, /*useDevPtrOperands=*/{},
7053 /*useDevAddrOperands=*/{}, hdaVars);
7054
7055 MapInfosTy combinedInfos;
7056 auto genMapInfoCB =
7057 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7058 builder.restoreIP(codeGenIP);
7059 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7060 targetDirective);
7061
7062 // Append a null entry for the implicit dyn_ptr argument so the argument
7063 // count sent to the runtime already includes it.
7064 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7065 combinedInfos.BasePointers.push_back(nullPtr);
7066 combinedInfos.Pointers.push_back(nullPtr);
7067 combinedInfos.DevicePointers.push_back(
7068 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7069 combinedInfos.Sizes.push_back(builder.getInt64(0));
7070 combinedInfos.Types.push_back(
7071 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7072 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7073 if (!combinedInfos.Names.empty())
7074 combinedInfos.Names.push_back(nullPtr);
7075 combinedInfos.Mappers.push_back(nullptr);
7076
7077 return combinedInfos;
7078 };
7079
7080 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7081 llvm::Value *&retVal, InsertPointTy allocaIP,
7082 InsertPointTy codeGenIP)
7083 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7084 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7085 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7086 // We just return the unaltered argument for the host function
7087 // for now, some alterations may be required in the future to
7088 // keep host fallback functions working identically to the device
7089 // version (e.g. pass ByCopy values should be treated as such on
7090 // host and device, currently not always the case)
7091 if (!isTargetDevice) {
7092 retVal = cast<llvm::Value>(&arg);
7093 return codeGenIP;
7094 }
7095
7096 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
7097 *ompBuilder, moduleTranslation,
7098 allocaIP, codeGenIP);
7099 };
7100
7101 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7102 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7103 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7104 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
7105 isTargetDevice, isGPU);
7106
7107 // Collect host-evaluated values needed to properly launch the kernel from the
7108 // host.
7109 if (!isTargetDevice)
7110 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
7111 targetCapturedOp, runtimeAttrs);
7112
7113 // Pass host-evaluated values as parameters to the kernel / host fallback,
7114 // except if they are constants. In any case, map the MLIR block argument to
7115 // the corresponding LLVM values.
7117 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
7118 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
7119 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7120 llvm::Value *value = moduleTranslation.lookupValue(var);
7121 moduleTranslation.mapValue(arg, value);
7122
7123 if (!llvm::isa<llvm::Constant>(value))
7124 kernelInput.push_back(value);
7125 }
7126
7127 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7128 // declare target arguments are not passed to kernels as arguments
7129 // TODO: We currently do not handle cases where a member is explicitly
7130 // passed in as an argument, this will likley need to be handled in
7131 // the near future, rather than using IsAMember, it may be better to
7132 // test if the relevant BlockArg is used within the target region and
7133 // then use that as a basis for exclusion in the kernel inputs.
7134 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
7135 kernelInput.push_back(mapData.OriginalValue[i]);
7136 }
7137
7139 buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
7140 moduleTranslation, dds);
7141
7142 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7143 findAllocaInsertPoint(builder, moduleTranslation);
7144 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7145
7146 llvm::OpenMPIRBuilder::TargetDataInfo info(
7147 /*RequiresDevicePointerInfo=*/false,
7148 /*SeparateBeginEndCalls=*/true);
7149
7150 auto customMapperCB =
7151 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
7152 if (!combinedInfos.Mappers[i])
7153 return nullptr;
7154 info.HasMapper = true;
7155 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
7156 moduleTranslation, targetDirective);
7157 };
7158
7159 llvm::Value *ifCond = nullptr;
7160 if (Value targetIfCond = targetOp.getIfExpr())
7161 ifCond = moduleTranslation.lookupValue(targetIfCond);
7162
7163 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7164 moduleTranslation.getOpenMPBuilder()->createTarget(
7165 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
7166 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
7167 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
7168
7169 if (failed(handleError(afterIP, opInst)))
7170 return failure();
7171
7172 builder.restoreIP(*afterIP);
7173
7174 // Remap access operations to declare target reference pointers for the
7175 // device, essentially generating extra loadop's as necessary
7176 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
7177 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
7178 llvmOutlinedFn);
7179
7180 return success();
7181}
7182
7183static LogicalResult
7184convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
7185 LLVM::ModuleTranslation &moduleTranslation) {
7186 // Amend omp.declare_target by deleting the IR of the outlined functions
7187 // created for target regions. They cannot be filtered out from MLIR earlier
7188 // because the omp.target operation inside must be translated to LLVM, but
7189 // the wrapper functions themselves must not remain at the end of the
7190 // process. We know that functions where omp.declare_target does not match
7191 // omp.is_target_device at this stage can only be wrapper functions because
7192 // those that aren't are removed earlier as an MLIR transformation pass.
7193 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7194 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7195 op->getParentOfType<ModuleOp>().getOperation())) {
7196 if (!offloadMod.getIsTargetDevice())
7197 return success();
7198
7199 omp::DeclareTargetDeviceType declareType =
7200 attribute.getDeviceType().getValue();
7201
7202 if (declareType == omp::DeclareTargetDeviceType::host) {
7203 llvm::Function *llvmFunc =
7204 moduleTranslation.lookupFunction(funcOp.getName());
7205 llvmFunc->dropAllReferences();
7206 llvmFunc->eraseFromParent();
7207 }
7208 }
7209 return success();
7210 }
7211
7212 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7213 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7214 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7215 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7216 bool isDeclaration = gOp.isDeclaration();
7217 bool isExternallyVisible =
7218 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
7219 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
7220 llvm::StringRef mangledName = gOp.getSymName();
7221 auto captureClause =
7222 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
7223 auto deviceClause =
7224 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
7225 // unused for MLIR at the moment, required in Clang for book
7226 // keeping
7227 std::vector<llvm::GlobalVariable *> generatedRefs;
7228
7229 std::vector<llvm::Triple> targetTriple;
7230 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7231 op->getParentOfType<mlir::ModuleOp>()->getAttr(
7232 LLVM::LLVMDialect::getTargetTripleAttrName()));
7233 if (targetTripleAttr)
7234 targetTriple.emplace_back(targetTripleAttr.data());
7235
7236 auto fileInfoCallBack = [&loc]() {
7237 std::string filename = "";
7238 std::uint64_t lineNo = 0;
7239
7240 if (loc) {
7241 filename = loc.getFilename().str();
7242 lineNo = loc.getLine();
7243 }
7244
7245 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7246 lineNo);
7247 };
7248
7249 auto vfs = llvm::vfs::getRealFileSystem();
7250
7251 ompBuilder->registerTargetGlobalVariable(
7252 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7253 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7254 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7255 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
7256 gVal->getType(), gVal);
7257
7258 if (ompBuilder->Config.isTargetDevice() &&
7259 (attribute.getCaptureClause().getValue() !=
7260 mlir::omp::DeclareTargetCaptureClause::to ||
7261 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7262 ompBuilder->getAddrOfDeclareTargetVar(
7263 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7264 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7265 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7266 gVal->getType(), /*GlobalInitializer*/ nullptr,
7267 /*VariableLinkage*/ nullptr);
7268 }
7269 }
7270 }
7271
7272 return success();
7273}
7274
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute (e.g. `omp.is_target_device`,
  /// `omp.flags`, `omp.declare_target`), create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
7299
7300LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7301 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7302 NamedAttribute attribute,
7303 LLVM::ModuleTranslation &moduleTranslation) const {
7304 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7305 attribute.getName())
7306 .Case("omp.is_target_device",
7307 [&](Attribute attr) {
7308 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7309 llvm::OpenMPIRBuilderConfig &config =
7310 moduleTranslation.getOpenMPBuilder()->Config;
7311 config.setIsTargetDevice(deviceAttr.getValue());
7312 return success();
7313 }
7314 return failure();
7315 })
7316 .Case("omp.is_gpu",
7317 [&](Attribute attr) {
7318 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7319 llvm::OpenMPIRBuilderConfig &config =
7320 moduleTranslation.getOpenMPBuilder()->Config;
7321 config.setIsGPU(gpuAttr.getValue());
7322 return success();
7323 }
7324 return failure();
7325 })
7326 .Case("omp.host_ir_filepath",
7327 [&](Attribute attr) {
7328 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7329 llvm::OpenMPIRBuilder *ompBuilder =
7330 moduleTranslation.getOpenMPBuilder();
7331 auto VFS = llvm::vfs::getRealFileSystem();
7332 ompBuilder->loadOffloadInfoMetadata(*VFS,
7333 filepathAttr.getValue());
7334 return success();
7335 }
7336 return failure();
7337 })
7338 .Case("omp.flags",
7339 [&](Attribute attr) {
7340 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7341 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
7342 return failure();
7343 })
7344 .Case("omp.version",
7345 [&](Attribute attr) {
7346 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7347 llvm::OpenMPIRBuilder *ompBuilder =
7348 moduleTranslation.getOpenMPBuilder();
7349 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
7350 versionAttr.getVersion());
7351 return success();
7352 }
7353 return failure();
7354 })
7355 .Case("omp.declare_target",
7356 [&](Attribute attr) {
7357 if (auto declareTargetAttr =
7358 dyn_cast<omp::DeclareTargetAttr>(attr))
7359 return convertDeclareTargetAttr(op, declareTargetAttr,
7360 moduleTranslation);
7361 return failure();
7362 })
7363 .Case("omp.requires",
7364 [&](Attribute attr) {
7365 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7366 using Requires = omp::ClauseRequires;
7367 Requires flags = requiresAttr.getValue();
7368 llvm::OpenMPIRBuilderConfig &config =
7369 moduleTranslation.getOpenMPBuilder()->Config;
7370 config.setHasRequiresReverseOffload(
7371 bitEnumContainsAll(flags, Requires::reverse_offload));
7372 config.setHasRequiresUnifiedAddress(
7373 bitEnumContainsAll(flags, Requires::unified_address));
7374 config.setHasRequiresUnifiedSharedMemory(
7375 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7376 config.setHasRequiresDynamicAllocators(
7377 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7378 return success();
7379 }
7380 return failure();
7381 })
7382 .Case("omp.target_triples",
7383 [&](Attribute attr) {
7384 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7385 llvm::OpenMPIRBuilderConfig &config =
7386 moduleTranslation.getOpenMPBuilder()->Config;
7387 config.TargetTriples.clear();
7388 config.TargetTriples.reserve(triplesAttr.size());
7389 for (Attribute tripleAttr : triplesAttr) {
7390 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7391 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7392 else
7393 return failure();
7394 }
7395 return success();
7396 }
7397 return failure();
7398 })
7399 .Default([](Attribute) {
7400 // Fall through for omp attributes that do not require lowering.
7401 return success();
7402 })(attribute.getValue());
7403
7404 return failure();
7405}
7406
7407// Returns true if the operation is not inside a TargetOp, it is part of a
7408// function and that function is not declare target.
7409static bool isHostDeviceOp(Operation *op) {
7410 // Assumes no reverse offloading
7411 if (op->getParentOfType<omp::TargetOp>())
7412 return false;
7413
7414 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
7415 if (auto declareTargetIface =
7416 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7417 parentFn.getOperation()))
7418 if (declareTargetIface.isDeclareTarget() &&
7419 declareTargetIface.getDeclareTargetDeviceType() !=
7420 mlir::omp::DeclareTargetDeviceType::host)
7421 return false;
7422
7423 return true;
7424 }
7425
7426 return false;
7427}
7428
7429static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7430 llvm::Module *llvmModule) {
7431 llvm::Type *i64Ty = builder.getInt64Ty();
7432 llvm::Type *i32Ty = builder.getInt32Ty();
7433 llvm::Type *returnType = builder.getPtrTy(0);
7434 llvm::FunctionType *fnType =
7435 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7436 llvm::Function *func = cast<llvm::Function>(
7437 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7438 return func;
7439}
7440
7441static LogicalResult
7442convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7443 LLVM::ModuleTranslation &moduleTranslation) {
7444 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7445 if (!allocMemOp)
7446 return failure();
7447
7448 // Get "omp_target_alloc" function
7449 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7450 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7451 // Get the corresponding device value in llvm
7452 mlir::Value deviceNum = allocMemOp.getDevice();
7453 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7454 // Get the allocation size.
7455 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7456 mlir::Type heapTy = allocMemOp.getAllocatedType();
7457 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
7458 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7459 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7460 for (auto typeParam : allocMemOp.getTypeparams())
7461 allocSize =
7462 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
7463 // Create call to "omp_target_alloc" with the args as translated llvm values.
7464 llvm::CallInst *call =
7465 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7466 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7467
7468 // Map the result
7469 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7470 return success();
7471}
7472
7473static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
7474 llvm::Module *llvmModule) {
7475 llvm::Type *ptrTy = builder.getPtrTy(0);
7476 llvm::Type *i32Ty = builder.getInt32Ty();
7477 llvm::Type *voidTy = builder.getVoidTy();
7478 llvm::FunctionType *fnType =
7479 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
7480 llvm::Function *func = dyn_cast<llvm::Function>(
7481 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
7482 return func;
7483}
7484
7485static LogicalResult
7486convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7487 LLVM::ModuleTranslation &moduleTranslation) {
7488 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7489 if (!freeMemOp)
7490 return failure();
7491
7492 // Get "omp_target_free" function
7493 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7494 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7495 // Get the corresponding device value in llvm
7496 mlir::Value deviceNum = freeMemOp.getDevice();
7497 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7498 // Get the corresponding heapref value in llvm
7499 mlir::Value heapref = freeMemOp.getHeapref();
7500 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7501 // Convert heapref int to ptr and call "omp_target_free"
7502 llvm::Value *intToPtr =
7503 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7504 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7505 return success();
7506}
7507
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
    Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // When translating for the device, reject ops that would only execute on
  // the host, except for the small set of ops that are always allowed to
  // appear (target, map.info, terminator, yield).
  if (ompBuilder->Config.isTargetDevice() &&
      !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
          op) &&
      isHostDeviceOp(op))
    return op->emitOpError() << "unsupported host op found in device";

  // For each loop, introduce one stack frame to hold loop information. Ensure
  // this is only done for the outermost loop wrapper to prevent introducing
  // multiple stack frames for a single loop. Initially set to null, the loop
  // information structure is initialized during translation of the nested
  // omp.loop_nest operation, making it available to translation of all loop
  // wrappers after their body has been successfully translated.
  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  // Dispatch to the op-specific translation routine.
  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
              return failure();

            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(afterIP, *op);
            if (res.succeeded()) {
              // If the barrier generated a cancellation check, the insertion
              // point might now need to be changed to a new continuation block
              builder.restoreIP(*afterIP);
            }
            return res;
          })
          .Case([&](omp::TaskyieldOp op) {
              return failure();

            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
              return failure();

            // No support in Openmp runtime function (__kmpc_flush) to accept
            // the argument list.
            // OpenMP standard states the following:
            // "An implementation may implement a flush with a list by ignoring
            // the list, and treating it the same as a flush without a list."
            //
            // The argument list is discarded so that, flush with a list is
            // treated same as a flush without a list.
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
          .Case([&](omp::ParallelOp op) {
            return convertOmpParallel(op, builder, moduleTranslation);
          })
          .Case([&](omp::MaskedOp) {
            return convertOmpMasked(*op, builder, moduleTranslation);
          })
          .Case([&](omp::MasterOp) {
            return convertOmpMaster(*op, builder, moduleTranslation);
          })
          .Case([&](omp::CriticalOp) {
            return convertOmpCritical(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedOp) {
            return convertOmpOrdered(*op, builder, moduleTranslation);
          })
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case([&](omp::SectionsOp) {
            return convertOmpSections(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SingleOp op) {
            return convertOmpSingle(op, builder, moduleTranslation);
          })
          .Case([&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskOp op) {
            return convertOmpTaskOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskloopOp op) {
            return convertOmpTaskloopOp(*op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
            // `yield` and `terminator` can be just omitted. The block structure
            // was created in the region that handles their parent operation.
            // `declare_reduction` will be used by reductions and is not
            // converted directly, skip it.
            // `declare_mapper` and `declare_mapper.info` are handled whenever
            // they are referred to through a `map` clause.
            // `critical.declare` is only used to declare names of critical
            // sections which will be used by `critical` ops and hence can be
            // ignored for lowering. The OpenMP IRBuilder will create unique
            // name for critical section names.
            return success();
          })
          .Case([&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(*op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetOp) {
            return convertOmpTarget(*op, builder, moduleTranslation);
          })
          .Case([&](omp::DistributeOp) {
            return convertOmpDistribute(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
                omp::AffinityEntryOp, omp::IteratorOp>([&](auto op) {
            // No-op, should be handled by relevant owning operations e.g.
            // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
            // etc. and then discarded
            return success();
          })
          .Case([&](omp::NewCliOp op) {
            // Meta-operation: Doesn't do anything by itself, but used to
            // identify a loop.
            return success();
          })
          .Case([&](omp::CanonicalLoopOp op) {
            return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::UnrollHeuristicOp op) {
            // FIXME: Handling omp.unroll_heuristic as an executable requires
            // that the generator (e.g. omp.canonical_loop) has been seen first.
            // For construct that require all codegen to occur inside a callback
            // (e.g. OpenMPIRBilder::createParallel), all codegen of that
            // contained region including their transformations must occur at
            // the omp.canonical_loop.
            return applyUnrollHeuristic(op, builder, moduleTranslation);
          })
          .Case([&](omp::TileOp op) {
            return applyTile(op, builder, moduleTranslation);
          })
          .Case([&](omp::FuseOp op) {
            return applyFuse(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetAllocMemOp) {
            return convertTargetAllocMemOp(*op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetFreeMemOp) {
            return convertTargetFreeMemOp(*op, builder, moduleTranslation);
          })
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  // Pop the loop-info frame pushed above, regardless of translation outcome.
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
7713
7715 registry.insert<omp::OpenMPDialect>();
7716 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7717 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7718 });
7719}
7720
7722 DialectRegistry registry;
7724 context.appendDialectRegistry(registry);
7725}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
static LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the remapped LLVM IR values corresponding to the given list of MLIR values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Value getOperand(unsigned idx)
Definition Operation.h:376
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
unsigned getNumOperands()
Definition Operation.h:372
OperandRange operand_range
Definition Operation.h:397
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars