MLIR 22.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
70/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
71/// insertion points for allocas.
72class OpenMPAllocaStackFrame
73 : public StateStackFrameBase<OpenMPAllocaStackFrame> {
74public:
76
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
80};
81
82/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
83/// collapsed canonical loop information corresponding to an \c omp.loop_nest
84/// operation.
85class OpenMPLoopInfoStackFrame
86    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
87public:
88  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  // Filled in while translating the loop nest body and read back by the
  // enclosing loop wrapper; remains null until body translation succeeds.
89  llvm::CanonicalLoopInfo *loopInfo = nullptr;
90};
91
92/// Custom error class to signal translation errors that don't need reporting,
93/// since encountering them will have already triggered relevant error messages.
94///
95/// Its purpose is to serve as the glue between MLIR failures represented as
96/// \see LogicalResult instances and \see llvm::Error instances used to
97/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
98/// error of the first type is raised, a message is emitted directly (the \see
99/// LogicalResult itself does not hold any information). If we need to forward
100/// this error condition as an \see llvm::Error while avoiding triggering some
101/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
102/// class to just signal this situation has happened.
103///
104/// For example, this class should be used to trigger errors from within
105/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
106/// translation of their own regions. This unclutters the error log from
107/// redundant messages.
108class PreviouslyReportedError
109 : public llvm::ErrorInfo<PreviouslyReportedError> {
110public:
111 void log(raw_ostream &) const override {
112 // Do not log anything.
113 }
114
115 std::error_code convertToErrorCode() const override {
116 llvm_unreachable(
117 "PreviouslyReportedError doesn't support ECError conversion");
118 }
119
120 // Used by ErrorInfo::classID.
121 static char ID;
122};
123
124char PreviouslyReportedError::ID = 0;
125
126/*
127 * Custom class for processing linear clause for omp.wsloop
128 * and omp.simd. Linear clause translation requires setup,
129 * initialization, update, and finalization at varying
130 * basic blocks in the IR. This class helps maintain
131 * internal state to allow consistent translation in
132 * each of these stages.
133 */
134
135class LinearClauseProcessor {
136
137private:
138 SmallVector<llvm::Value *> linearPreconditionVars;
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
140 SmallVector<llvm::Value *> linearOrigVal;
141 SmallVector<llvm::Value *> linearSteps;
142 SmallVector<llvm::Type *> linearVarTypes;
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
146
147public:
148 // Register type for the linear variables
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
153 }
154
155 // Allocate space for linear variabes
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar, int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
165 }
166
167 // Initialize linear step
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
171 }
172
173 // Emit IR for initialization of linear variables
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
182 }
183 }
184
185 // Emit IR for updating Linear variables
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
193
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable("OpenMP loop induction variable must be an integer "
196 "type");
197
198 if (linearVarType->isIntegerTy()) {
199 // Integer path: normalize all arithmetic to linearVarType
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
202
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 } else if (linearVarType->isFloatingPointTy()) {
209 // Float path: perform multiply in integer, then convert to float
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
212
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
218 } else {
219 llvm_unreachable(
220 "Linear variable must be of integer or floating-point type");
221 }
222 }
223 }
224
225 // Linear variable finalization is conditional on the last logical iteration.
226 // Create BB splits to manage the same.
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(), "omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
235 }
236
237 // Finalize the linear vars
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
242 // Emit condition to check whether last logical iteration is being executed
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
250 // Store the linear variable values to original variables.
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
256 }
257
258 // Create conditional branch such that the linear variable
259 // values are stored to original variables only at the
260 // last logical iteration
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
264 // Emit barrier
265 builder.SetInsertPoint(linearExitBB->getTerminator());
266 return moduleTranslation.getOpenMPBuilder()->createBarrier(
267 builder.saveIP(), llvm::omp::OMPD_barrier);
268 }
269
270 // Emit stores for linear variables. Useful in case of SIMD
271 // construct.
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
277 }
278 }
279
280 // Rewrite all uses of the original variable in `BBName`
281 // with the linear variable in-place
282 void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
283 size_t varIndex) {
284 llvm::SmallVector<llvm::User *> users;
285 for (llvm::User *user : linearOrigVal[varIndex]->users())
286 users.push_back(user);
287 for (auto *user : users) {
288 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
289 if (userInst->getParent()->getName().str().find(BBName) !=
290 std::string::npos)
291 user->replaceUsesOfWith(linearOrigVal[varIndex],
292 linearLoopBodyTemps[varIndex]);
293 }
294 }
295 }
296};
297
298} // namespace
299
300/// Looks up from the operation from and returns the PrivateClauseOp with
301/// name symbolName
302static omp::PrivateClauseOp findPrivatizer(Operation *from,
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
306 symbolName);
307 assert(privatizer && "privatizer not found in the symbol table");
308 return privatizer;
309}
310
311/// Check whether translation to LLVM IR for the given operation is currently
312/// supported. If not, descriptive diagnostics will be emitted to let users know
313/// this is a not-yet-implemented feature.
314///
315/// \returns success if no unimplemented features are needed to translate the
316/// given operation.
317static LogicalResult checkImplementationStatus(Operation &op) {
318 auto todo = [&op](StringRef clauseName) {
319 return op.emitError() << "not yet implemented: Unhandled clause "
320 << clauseName << " in " << op.getName()
321 << " operation";
322 };
323
324 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
325 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
326 result = todo("allocate");
327 };
328 auto checkBare = [&todo](auto op, LogicalResult &result) {
329 if (op.getBare())
330 result = todo("ompx_bare");
331 };
332 auto checkCollapse = [&todo](auto op, LogicalResult &result) {
333 if (op.getCollapseNumLoops() > 1)
334 result = todo("collapse");
335 };
336 auto checkDepend = [&todo](auto op, LogicalResult &result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
338 result = todo("depend");
339 };
340 auto checkHint = [](auto op, LogicalResult &) {
341 if (op.getHint())
342 op.emitWarning("hint clause discarded");
343 };
344 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo("in_reduction");
348 };
349 auto checkNowait = [&todo](auto op, LogicalResult &result) {
350 if (op.getNowait())
351 result = todo("nowait");
352 };
353 auto checkOrder = [&todo](auto op, LogicalResult &result) {
354 if (op.getOrder() || op.getOrderMod())
355 result = todo("order");
356 };
357 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
358 if (op.getParLevelSimd())
359 result = todo("parallelization-level");
360 };
361 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo("privatization");
364 };
365 auto checkReduction = [&todo](auto op, LogicalResult &result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo("reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo("reduction with modifier");
373 };
374 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo("task_reduction");
378 };
379
380 LogicalResult result = success();
382 .Case([&](omp::DistributeOp op) {
383 checkAllocate(op, result);
384 checkOrder(op, result);
385 })
386 .Case([&](omp::LoopNestOp op) {
387 if (mlir::isa<omp::TaskloopOp>(op.getOperation()->getParentOp()))
388 checkCollapse(op, result);
389 })
390 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
391 .Case([&](omp::SectionsOp op) {
392 checkAllocate(op, result);
393 checkPrivate(op, result);
394 checkReduction(op, result);
395 })
396 .Case([&](omp::SingleOp op) {
397 checkAllocate(op, result);
398 checkPrivate(op, result);
399 })
400 .Case([&](omp::TeamsOp op) {
401 checkAllocate(op, result);
402 checkPrivate(op, result);
403 })
404 .Case([&](omp::TaskOp op) {
405 checkAllocate(op, result);
406 checkInReduction(op, result);
407 })
408 .Case([&](omp::TaskgroupOp op) {
409 checkAllocate(op, result);
410 checkTaskReduction(op, result);
411 })
412 .Case([&](omp::TaskwaitOp op) {
413 checkDepend(op, result);
414 checkNowait(op, result);
415 })
416 .Case([&](omp::TaskloopOp op) {
417 checkAllocate(op, result);
418 checkInReduction(op, result);
419 checkReduction(op, result);
420 })
421 .Case([&](omp::WsloopOp op) {
422 checkAllocate(op, result);
423 checkOrder(op, result);
424 checkReduction(op, result);
425 })
426 .Case([&](omp::ParallelOp op) {
427 checkAllocate(op, result);
428 checkReduction(op, result);
429 })
430 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
431 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
432 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
433 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
434 [&](auto op) { checkDepend(op, result); })
435 .Case<omp::TargetUpdateOp>([&](auto op) { checkDepend(op, result); })
436 .Case([&](omp::TargetOp op) {
437 checkAllocate(op, result);
438 checkBare(op, result);
439 checkInReduction(op, result);
440 })
441 .Default([](Operation &) {
442 // Assume all clauses for an operation can be translated unless they are
443 // checked above.
444 });
445 return result;
446}
447
448static LogicalResult handleError(llvm::Error error, Operation &op) {
449 LogicalResult result = success();
450 if (error) {
451 llvm::handleAllErrors(
452 std::move(error),
453 [&](const PreviouslyReportedError &) { result = failure(); },
454 [&](const llvm::ErrorInfoBase &err) {
455 result = op.emitError(err.message());
456 });
457 }
458 return result;
459}
460
461template <typename T>
462static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
463 if (!result)
464 return handleError(result.takeError(), op);
465
466 return success();
467}
468
469/// Find the insertion point for allocas given the current insertion point for
470/// normal operations in the builder.
471static llvm::OpenMPIRBuilder::InsertPointTy
472findAllocaInsertPoint(llvm::IRBuilderBase &builder,
473 LLVM::ModuleTranslation &moduleTranslation) {
474 // If there is an alloca insertion point on stack, i.e. we are in a nested
475 // operation and a specific point was provided by some surrounding operation,
476 // use it.
477 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
478 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
479 [&](OpenMPAllocaStackFrame &frame) {
480 allocaInsertPoint = frame.allocaInsertPoint;
481 return WalkResult::interrupt();
482 });
483 // In cases with multiple levels of outlining, the tree walk might find an
484 // alloca insertion point that is inside the original function while the
485 // builder insertion point is inside the outlined function. We need to make
486 // sure that we do not use it in those cases.
487 if (walkResult.wasInterrupted() &&
488 allocaInsertPoint.getBlock()->getParent() ==
489 builder.GetInsertBlock()->getParent())
490 return allocaInsertPoint;
491
492 // Otherwise, insert to the entry block of the surrounding function.
493 // If the current IRBuilder InsertPoint is the function's entry, it cannot
494 // also be used for alloca insertion which would result in insertion order
495 // confusion. Create a new BasicBlock for the Builder and use the entry block
496 // for the allocs.
497 // TODO: Create a dedicated alloca BasicBlock at function creation such that
498 // we do not need to move the current InertPoint here.
499 if (builder.GetInsertBlock() ==
500 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
501 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
502 "Assuming end of basic block");
503 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
504 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
505 builder.GetInsertBlock()->getNextNode());
506 builder.CreateBr(entryBB);
507 builder.SetInsertPoint(entryBB);
508 }
509
510 llvm::BasicBlock &funcEntryBlock =
511 builder.GetInsertBlock()->getParent()->getEntryBlock();
512 return llvm::OpenMPIRBuilder::InsertPointTy(
513 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
514}
515
516/// Find the loop information structure for the loop nest being translated. It
517/// will return a `null` value unless called from the translation function for
518/// a loop wrapper operation after successfully translating its body.
519static llvm::CanonicalLoopInfo *
521 llvm::CanonicalLoopInfo *loopInfo = nullptr;
522 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
523 [&](OpenMPLoopInfoStackFrame &frame) {
524 loopInfo = frame.loopInfo;
525 return WalkResult::interrupt();
526 });
527 return loopInfo;
528}
529
530/// Converts the given region that appears within an OpenMP dialect operation to
531/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
532/// region, and a branch from any block with an successor-less OpenMP terminator
533/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
534/// of the continuation block if provided.
536 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
537 LLVM::ModuleTranslation &moduleTranslation,
538 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
539 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
540
541 llvm::BasicBlock *continuationBlock =
542 splitBB(builder, true, "omp.region.cont");
543 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
544
545 llvm::LLVMContext &llvmContext = builder.getContext();
546 for (Block &bb : region) {
547 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
548 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
549 builder.GetInsertBlock()->getNextNode());
550 moduleTranslation.mapBlock(&bb, llvmBB);
551 }
552
553 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
554
555 // Terminators (namely YieldOp) may be forwarding values to the region that
556 // need to be available in the continuation block. Collect the types of these
557 // operands in preparation of creating PHI nodes. This is skipped for loop
558 // wrapper operations, for which we know in advance they have no terminators.
559 SmallVector<llvm::Type *> continuationBlockPHITypes;
560 unsigned numYields = 0;
561
562 if (!isLoopWrapper) {
563 bool operandsProcessed = false;
564 for (Block &bb : region.getBlocks()) {
565 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
566 if (!operandsProcessed) {
567 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
568 continuationBlockPHITypes.push_back(
569 moduleTranslation.convertType(yield->getOperand(i).getType()));
570 }
571 operandsProcessed = true;
572 } else {
573 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
574 "mismatching number of values yielded from the region");
575 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
576 llvm::Type *operandType =
577 moduleTranslation.convertType(yield->getOperand(i).getType());
578 (void)operandType;
579 assert(continuationBlockPHITypes[i] == operandType &&
580 "values of mismatching types yielded from the region");
581 }
582 }
583 numYields++;
584 }
585 }
586 }
587
588 // Insert PHI nodes in the continuation block for any values forwarded by the
589 // terminators in this region.
590 if (!continuationBlockPHITypes.empty())
591 assert(
592 continuationBlockPHIs &&
593 "expected continuation block PHIs if converted regions yield values");
594 if (continuationBlockPHIs) {
595 llvm::IRBuilderBase::InsertPointGuard guard(builder);
596 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
597 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
598 for (llvm::Type *ty : continuationBlockPHITypes)
599 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
600 }
601
602 // Convert blocks one by one in topological order to ensure
603 // defs are converted before uses.
605 for (Block *bb : blocks) {
606 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
607 // Retarget the branch of the entry block to the entry block of the
608 // converted region (regions are single-entry).
609 if (bb->isEntryBlock()) {
610 assert(sourceTerminator->getNumSuccessors() == 1 &&
611 "provided entry block has multiple successors");
612 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
613 "ContinuationBlock is not the successor of the entry block");
614 sourceTerminator->setSuccessor(0, llvmBB);
615 }
616
617 llvm::IRBuilderBase::InsertPointGuard guard(builder);
618 if (failed(
619 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
620 return llvm::make_error<PreviouslyReportedError>();
621
622 // Create a direct branch here for loop wrappers to prevent their lack of a
623 // terminator from causing a crash below.
624 if (isLoopWrapper) {
625 builder.CreateBr(continuationBlock);
626 continue;
627 }
628
629 // Special handling for `omp.yield` and `omp.terminator` (we may have more
630 // than one): they return the control to the parent OpenMP dialect operation
631 // so replace them with the branch to the continuation block. We handle this
632 // here to avoid relying inter-function communication through the
633 // ModuleTranslation class to set up the correct insertion point. This is
634 // also consistent with MLIR's idiom of handling special region terminators
635 // in the same code that handles the region-owning operation.
636 Operation *terminator = bb->getTerminator();
637 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
638 builder.CreateBr(continuationBlock);
639
640 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
641 (*continuationBlockPHIs)[i]->addIncoming(
642 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
643 }
644 }
645 // After all blocks have been traversed and values mapped, connect the PHI
646 // nodes to the results of preceding blocks.
647 LLVM::detail::connectPHINodes(region, moduleTranslation);
648
649 // Remove the blocks and values defined in this region from the mapping since
650 // they are not visible outside of this region. This allows the same region to
651 // be converted several times, that is cloned, without clashes, and slightly
652 // speeds up the lookups.
653 moduleTranslation.forgetMapping(region);
654
655 return continuationBlock;
656}
657
658/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
659static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
660 switch (kind) {
661 case omp::ClauseProcBindKind::Close:
662 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
663 case omp::ClauseProcBindKind::Master:
664 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
665 case omp::ClauseProcBindKind::Primary:
666 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
667 case omp::ClauseProcBindKind::Spread:
668 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
669 }
670 llvm_unreachable("Unknown ClauseProcBindKind kind");
671}
672
673/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
674/// corresponding LLVM values of \p the interface's operands. This is useful
675/// when an OpenMP region with entry block arguments is converted to LLVM. In
676/// this case the block arguments are (part of) of the OpenMP region's entry
677/// arguments and the operands are (part of) of the operands to the OpenMP op
678/// containing the region.
679static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
680 omp::BlockArgOpenMPOpInterface blockArgIface) {
682 blockArgIface.getBlockArgsPairs(blockArgsPairs);
683 for (auto [var, arg] : blockArgsPairs)
684 moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
685}
686
687/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
688static LogicalResult
689convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
690 LLVM::ModuleTranslation &moduleTranslation) {
691 auto maskedOp = cast<omp::MaskedOp>(opInst);
692 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
693
694 if (failed(checkImplementationStatus(opInst)))
695 return failure();
696
697 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
698 // MaskedOp has only one region associated with it.
699 auto &region = maskedOp.getRegion();
700 builder.restoreIP(codeGenIP);
701 return convertOmpOpRegions(region, "omp.masked.region", builder,
702 moduleTranslation)
703 .takeError();
704 };
705
706 // TODO: Perform finalization actions for variables. This has to be
707 // called for variables which have destructors/finalizers.
708 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
709
710 llvm::Value *filterVal = nullptr;
711 if (auto filterVar = maskedOp.getFilteredThreadId()) {
712 filterVal = moduleTranslation.lookupValue(filterVar);
713 } else {
714 llvm::LLVMContext &llvmContext = builder.getContext();
715 filterVal =
716 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
717 }
718 assert(filterVal != nullptr);
719 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
720 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
721 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
722 finiCB, filterVal);
723
724 if (failed(handleError(afterIP, opInst)))
725 return failure();
726
727 builder.restoreIP(*afterIP);
728 return success();
729}
730
731/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
732static LogicalResult
733convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
734 LLVM::ModuleTranslation &moduleTranslation) {
735 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
736 auto masterOp = cast<omp::MasterOp>(opInst);
737
738 if (failed(checkImplementationStatus(opInst)))
739 return failure();
740
741 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
742 // MasterOp has only one region associated with it.
743 auto &region = masterOp.getRegion();
744 builder.restoreIP(codeGenIP);
745 return convertOmpOpRegions(region, "omp.master.region", builder,
746 moduleTranslation)
747 .takeError();
748 };
749
750 // TODO: Perform finalization actions for variables. This has to be
751 // called for variables which have destructors/finalizers.
752 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
753
754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
755 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
756 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
757 finiCB);
758
759 if (failed(handleError(afterIP, opInst)))
760 return failure();
761
762 builder.restoreIP(*afterIP);
763 return success();
764}
765
766/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
767static LogicalResult
768convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
769 LLVM::ModuleTranslation &moduleTranslation) {
770 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
771 auto criticalOp = cast<omp::CriticalOp>(opInst);
772
773 if (failed(checkImplementationStatus(opInst)))
774 return failure();
775
776 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
777 // CriticalOp has only one region associated with it.
778 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
779 builder.restoreIP(codeGenIP);
780 return convertOmpOpRegions(region, "omp.critical.region", builder,
781 moduleTranslation)
782 .takeError();
783 };
784
785 // TODO: Perform finalization actions for variables. This has to be
786 // called for variables which have destructors/finalizers.
787 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
788
789 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
790 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
791 llvm::Constant *hint = nullptr;
792
793 // If it has a name, it probably has a hint too.
794 if (criticalOp.getNameAttr()) {
795 // The verifiers in OpenMP Dialect guarentee that all the pointers are
796 // non-null
797 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
798 auto criticalDeclareOp =
800 symbolRef);
801 hint =
802 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
803 static_cast<int>(criticalDeclareOp.getHint()));
804 }
805 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
806 moduleTranslation.getOpenMPBuilder()->createCritical(
807 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
808
809 if (failed(handleError(afterIP, opInst)))
810 return failure();
811
812 builder.restoreIP(*afterIP);
813 return success();
814}
815
816/// A util to collect info needed to convert delayed privatizers from MLIR to
817/// LLVM.
819 template <typename OP>
821 : blockArgs(
822 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
823 mlirVars.reserve(blockArgs.size());
824 llvmVars.reserve(blockArgs.size());
825 collectPrivatizationDecls<OP>(op);
826
827 for (mlir::Value privateVar : op.getPrivateVars())
828 mlirVars.push_back(privateVar);
829 }
830
835
836private:
837 /// Populates `privatizations` with privatization declarations used for the
838 /// given op.
839 template <class OP>
840 void collectPrivatizationDecls(OP op) {
841 std::optional<ArrayAttr> attr = op.getPrivateSyms();
842 if (!attr)
843 return;
844
845 privatizers.reserve(privatizers.size() + attr->size());
846 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
847 privatizers.push_back(findPrivatizer(op, symbolRef));
848 }
849 }
850};
851
852/// Populates `reductions` with reduction declarations used in the given op.
853template <typename T>
854static void
857 std::optional<ArrayAttr> attr = op.getReductionSyms();
858 if (!attr)
859 return;
860
861 reductions.reserve(reductions.size() + op.getNumReductionVars());
862 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
863 reductions.push_back(
865 op, symbolRef));
866 }
867}
868
869/// Translates the blocks contained in the given region and appends them to at
870/// the current insertion point of `builder`. The operations of the entry block
871/// are appended to the current insertion block. If set, `continuationBlockArgs`
872/// is populated with translated values that correspond to the values
873/// omp.yield'ed from the region.
874static LogicalResult inlineConvertOmpRegions(
875 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
876 LLVM::ModuleTranslation &moduleTranslation,
877 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
878 if (region.empty())
879 return success();
880
881 // Special case for single-block regions that don't create additional blocks:
882 // insert operations without creating additional blocks.
883 if (region.hasOneBlock()) {
884 llvm::Instruction *potentialTerminator =
885 builder.GetInsertBlock()->empty() ? nullptr
886 : &builder.GetInsertBlock()->back();
887
888 if (potentialTerminator && potentialTerminator->isTerminator())
889 potentialTerminator->removeFromParent();
890 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
891
892 if (failed(moduleTranslation.convertBlock(
893 region.front(), /*ignoreArguments=*/true, builder)))
894 return failure();
895
896 // The continuation arguments are simply the translated terminator operands.
897 if (continuationBlockArgs)
898 llvm::append_range(
899 *continuationBlockArgs,
900 moduleTranslation.lookupValues(region.front().back().getOperands()));
901
902 // Drop the mapping that is no longer necessary so that the same region can
903 // be processed multiple times.
904 moduleTranslation.forgetMapping(region);
905
906 if (potentialTerminator && potentialTerminator->isTerminator()) {
907 llvm::BasicBlock *block = builder.GetInsertBlock();
908 if (block->empty()) {
909 // this can happen for really simple reduction init regions e.g.
910 // %0 = llvm.mlir.constant(0 : i32) : i32
911 // omp.yield(%0 : i32)
912 // because the llvm.mlir.constant (MLIR op) isn't converted into any
913 // llvm op
914 potentialTerminator->insertInto(block, block->begin());
915 } else {
916 potentialTerminator->insertAfter(&block->back());
917 }
918 }
919
920 return success();
921 }
922
924 llvm::Expected<llvm::BasicBlock *> continuationBlock =
925 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
926
927 if (failed(handleError(continuationBlock, *region.getParentOp())))
928 return failure();
929
930 if (continuationBlockArgs)
931 llvm::append_range(*continuationBlockArgs, phis);
932 builder.SetInsertPoint(*continuationBlock,
933 (*continuationBlock)->getFirstInsertionPt());
934 return success();
935}
936
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: given an insertion point and the LHS/RHS values,
/// emits the combination and returns the result through the out-parameter.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
/// Atomic combiner variant; the extra llvm::Type* describes the element type.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// Generator for the `data_ptr_ptr` region of by-ref reductions; yields the
/// data pointer through the out-parameter.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
952
953/// Create an OpenMPIRBuilder-compatible reduction generator for the given
954/// reduction declaration. The generator uses `builder` but ignores its
955/// insertion point.
956static OwningReductionGen
957makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
958 LLVM::ModuleTranslation &moduleTranslation) {
959 // The lambda is mutable because we need access to non-const methods of decl
960 // (which aren't actually mutating it), and we must capture decl by-value to
961 // avoid the dangling reference after the parent function returns.
962 OwningReductionGen gen =
963 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
964 llvm::Value *lhs, llvm::Value *rhs,
965 llvm::Value *&result) mutable
966 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
967 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
968 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
969 builder.restoreIP(insertPoint);
971 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
972 "omp.reduction.nonatomic.body", builder,
973 moduleTranslation, &phis)))
974 return llvm::createStringError(
975 "failed to inline `combiner` region of `omp.declare_reduction`");
976 result = llvm::getSingleElement(phis);
977 return builder.saveIP();
978 };
979 return gen;
980}
981
982/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
983/// given reduction declaration. The generator uses `builder` but ignores its
984/// insertion point. Returns null if there is no atomic region available in the
985/// reduction declaration.
986static OwningAtomicReductionGen
987makeAtomicReductionGen(omp::DeclareReductionOp decl,
988 llvm::IRBuilderBase &builder,
989 LLVM::ModuleTranslation &moduleTranslation) {
990 if (decl.getAtomicReductionRegion().empty())
991 return OwningAtomicReductionGen();
992
993 // The lambda is mutable because we need access to non-const methods of decl
994 // (which aren't actually mutating it), and we must capture decl by-value to
995 // avoid the dangling reference after the parent function returns.
996 OwningAtomicReductionGen atomicGen =
997 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
998 llvm::Value *lhs, llvm::Value *rhs) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1002 builder.restoreIP(insertPoint);
1004 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1005 "omp.reduction.atomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 "failed to inline `atomic` region of `omp.declare_reduction`");
1009 assert(phis.empty());
1010 return builder.saveIP();
1011 };
1012 return atomicGen;
1013}
1014
1015/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1016/// the given reduction declaration. The generator uses `builder` but ignores
1017/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1018/// available in the reduction declaration.
1019static OwningDataPtrPtrReductionGen
1020makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1021 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1022 if (!isByRef)
1023 return OwningDataPtrPtrReductionGen();
1024
1025 OwningDataPtrPtrReductionGen refDataPtrGen =
1026 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1027 llvm::Value *byRefVal, llvm::Value *&result) mutable
1028 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1029 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1030 builder.restoreIP(insertPoint);
1032 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1033 "omp.data_ptr_ptr.body", builder,
1034 moduleTranslation, &phis)))
1035 return llvm::createStringError(
1036 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1037 result = llvm::getSingleElement(phis);
1038 return builder.saveIP();
1039 };
1040
1041 return refDataPtrGen;
1042}
1043
1044/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1045static LogicalResult
1046convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1047 LLVM::ModuleTranslation &moduleTranslation) {
1048 auto orderedOp = cast<omp::OrderedOp>(opInst);
1049
1050 if (failed(checkImplementationStatus(opInst)))
1051 return failure();
1052
1053 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1054 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1055 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1056 SmallVector<llvm::Value *> vecValues =
1057 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1058
1059 size_t indexVecValues = 0;
1060 while (indexVecValues < vecValues.size()) {
1061 SmallVector<llvm::Value *> storeValues;
1062 storeValues.reserve(numLoops);
1063 for (unsigned i = 0; i < numLoops; i++) {
1064 storeValues.push_back(vecValues[indexVecValues]);
1065 indexVecValues++;
1066 }
1067 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1068 findAllocaInsertPoint(builder, moduleTranslation);
1069 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1070 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1071 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1072 }
1073 return success();
1074}
1075
1076/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1077/// OpenMPIRBuilder.
1078static LogicalResult
1079convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1080 LLVM::ModuleTranslation &moduleTranslation) {
1081 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1082 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1083
1084 if (failed(checkImplementationStatus(opInst)))
1085 return failure();
1086
1087 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1088 // OrderedOp has only one region associated with it.
1089 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1090 builder.restoreIP(codeGenIP);
1091 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1092 moduleTranslation)
1093 .takeError();
1094 };
1095
1096 // TODO: Perform finalization actions for variables. This has to be
1097 // called for variables which have destructors/finalizers.
1098 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1099
1100 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1101 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1102 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1103 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1104
1105 if (failed(handleError(afterIP, opInst)))
1106 return failure();
1107
1108 builder.restoreIP(*afterIP);
1109 return success();
1110}
1111
1112namespace {
1113/// Contains the arguments for an LLVM store operation
1114struct DeferredStore {
1115 DeferredStore(llvm::Value *value, llvm::Value *address)
1116 : value(value), address(address) {}
1117
1118 llvm::Value *value;
1119 llvm::Value *address;
1120};
1121} // namespace
1122
1123/// Allocate space for privatized reduction variables.
1124/// `deferredStores` contains information to create store operations which needs
1125/// to be inserted after all allocas
1126template <typename T>
1127static LogicalResult
1129 llvm::IRBuilderBase &builder,
1130 LLVM::ModuleTranslation &moduleTranslation,
1131 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1133 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1134 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1135 SmallVectorImpl<DeferredStore> &deferredStores,
1136 llvm::ArrayRef<bool> isByRefs) {
1137 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1138 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1139
1140 // delay creating stores until after all allocas
1141 deferredStores.reserve(loop.getNumReductionVars());
1142
1143 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1144 Region &allocRegion = reductionDecls[i].getAllocRegion();
1145 if (isByRefs[i]) {
1146 if (allocRegion.empty())
1147 continue;
1148
1150 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1151 builder, moduleTranslation, &phis)))
1152 return loop.emitError(
1153 "failed to inline `alloc` region of `omp.declare_reduction`");
1154
1155 assert(phis.size() == 1 && "expected one allocation to be yielded");
1156 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1157
1158 // Allocate reduction variable (which is a pointer to the real reduction
1159 // variable allocated in the inlined region)
1160 llvm::Value *var = builder.CreateAlloca(
1161 moduleTranslation.convertType(reductionDecls[i].getType()));
1162
1163 llvm::Type *ptrTy = builder.getPtrTy();
1164 llvm::Value *castVar =
1165 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1166 llvm::Value *castPhi =
1167 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1168
1169 deferredStores.emplace_back(castPhi, castVar);
1170
1171 privateReductionVariables[i] = castVar;
1172 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1173 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1174 } else {
1175 assert(allocRegion.empty() &&
1176 "allocaction is implicit for by-val reduction");
1177 llvm::Value *var = builder.CreateAlloca(
1178 moduleTranslation.convertType(reductionDecls[i].getType()));
1179
1180 llvm::Type *ptrTy = builder.getPtrTy();
1181 llvm::Value *castVar =
1182 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1183
1184 moduleTranslation.mapValue(reductionArgs[i], castVar);
1185 privateReductionVariables[i] = castVar;
1186 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1187 }
1188 }
1189
1190 return success();
1191}
1192
1193/// Map input arguments to reduction initialization region
1194template <typename T>
1195static void
1197 llvm::IRBuilderBase &builder,
1199 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1200 unsigned i) {
1201 // map input argument to the initialization region
1202 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1203 Region &initializerRegion = reduction.getInitializerRegion();
1204 Block &entry = initializerRegion.front();
1205
1206 mlir::Value mlirSource = loop.getReductionVars()[i];
1207 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1208 llvm::Value *origVal = llvmSource;
1209 // If a non-pointer value is expected, load the value from the source pointer.
1210 if (!isa<LLVM::LLVMPointerType>(
1211 reduction.getInitializerMoldArg().getType()) &&
1212 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1213 origVal =
1214 builder.CreateLoad(moduleTranslation.convertType(
1215 reduction.getInitializerMoldArg().getType()),
1216 llvmSource, "omp_orig");
1217 }
1218 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1219
1220 if (entry.getNumArguments() > 1) {
1221 llvm::Value *allocation =
1222 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1223 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1224 }
1225}
1226
1227static void
1228setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1229 llvm::BasicBlock *block = nullptr) {
1230 if (block == nullptr)
1231 block = builder.GetInsertBlock();
1232
1233 if (block->empty() || block->getTerminator() == nullptr)
1234 builder.SetInsertPoint(block);
1235 else
1236 builder.SetInsertPoint(block->getTerminator());
1237}
1238
1239/// Inline reductions' `init` regions. This functions assumes that the
1240/// `builder`'s insertion point is where the user wants the `init` regions to be
1241/// inlined; i.e. it does not try to find a proper insertion location for the
1242/// `init` regions. It also leaves the `builder's insertions point in a state
1243/// where the user can continue the code-gen directly afterwards.
1244template <typename OP>
1245static LogicalResult
1246initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1247 llvm::IRBuilderBase &builder,
1248 LLVM::ModuleTranslation &moduleTranslation,
1249 llvm::BasicBlock *latestAllocaBlock,
1251 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1252 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1253 llvm::ArrayRef<bool> isByRef,
1254 SmallVectorImpl<DeferredStore> &deferredStores) {
1255 if (op.getNumReductionVars() == 0)
1256 return success();
1257
1258 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1259 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1260 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1261 builder.restoreIP(allocaIP);
1262 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1263
1264 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1265 if (isByRef[i]) {
1266 if (!reductionDecls[i].getAllocRegion().empty())
1267 continue;
1268
1269 // TODO: remove after all users of by-ref are updated to use the alloc
1270 // region: Allocate reduction variable (which is a pointer to the real
1271 // reduciton variable allocated in the inlined region)
1272 byRefVars[i] = builder.CreateAlloca(
1273 moduleTranslation.convertType(reductionDecls[i].getType()));
1274 }
1275 }
1276
1277 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1278
1279 // store result of the alloc region to the allocated pointer to the real
1280 // reduction variable
1281 for (auto [data, addr] : deferredStores)
1282 builder.CreateStore(data, addr);
1283
1284 // Before the loop, store the initial values of reductions into reduction
1285 // variables. Although this could be done after allocas, we don't want to mess
1286 // up with the alloca insertion point.
1287 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1289
1290 // map block argument to initializer region
1291 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1292 reductionVariableMap, i);
1293
1294 // TODO In some cases (specially on the GPU), the init regions may
1295 // contains stack alloctaions. If the region is inlined in a loop, this is
1296 // problematic. Instead of just inlining the region, handle allocations by
1297 // hoisting fixed length allocations to the function entry and using
1298 // stacksave and restore for variable length ones.
1299 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1300 "omp.reduction.neutral", builder,
1301 moduleTranslation, &phis)))
1302 return failure();
1303
1304 assert(phis.size() == 1 && "expected one value to be yielded from the "
1305 "reduction neutral element declaration region");
1306
1308
1309 if (isByRef[i]) {
1310 if (!reductionDecls[i].getAllocRegion().empty())
1311 // done in allocReductionVars
1312 continue;
1313
1314 // TODO: this path can be removed once all users of by-ref are updated to
1315 // use an alloc region
1316
1317 // Store the result of the inlined region to the allocated reduction var
1318 // ptr
1319 builder.CreateStore(phis[0], byRefVars[i]);
1320
1321 privateReductionVariables[i] = byRefVars[i];
1322 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1323 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1324 } else {
1325 // for by-ref case the store is inside of the reduction region
1326 builder.CreateStore(phis[0], privateReductionVariables[i]);
1327 // the rest was handled in allocByValReductionVars
1328 }
1329
1330 // forget the mapping for the initializer region because we might need a
1331 // different mapping if this reduction declaration is re-used for a
1332 // different variable
1333 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1334 }
1335
1336 return success();
1337}
1338
/// Collect reduction info
/// Builds the owning generator callbacks for every reduction of `loop` and
/// assembles the OpenMPIRBuilder::ReductionInfo list from them.
/// NOTE(review): several lines of this function appear to have been lost in
/// extraction (parts of the parameter list, the walk lambda's return
/// statements, and some ReductionInfo fields) — verify against the upstream
/// source before relying on this listing.
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    ArrayRef<bool> isByRef) {
  unsigned numReductions = loop.getNumReductionVars();

  // Create one owning generator of each kind per reduction.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    // A null atomic generator means "no atomic region available".
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    // Find the element type of the first alloca in the alloc region, if any.
    mlir::Type allocatedType;
    reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
      if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
        allocatedType = alloca.getElemType();
      }

    });

    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         /*ReductionGenClang=*/nullptr, atomicGen,
         allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
         reductionDecls[i].getByrefElementType()
             ? moduleTranslation.convertType(
                   *reductionDecls[i].getByrefElementType())
             : nullptr});
  }
}
1394
1395/// handling of DeclareReductionOp's cleanup region
1396static LogicalResult
1398 llvm::ArrayRef<llvm::Value *> privateVariables,
1399 LLVM::ModuleTranslation &moduleTranslation,
1400 llvm::IRBuilderBase &builder, StringRef regionName,
1401 bool shouldLoadCleanupRegionArg = true) {
1402 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1403 if (cleanupRegion->empty())
1404 continue;
1405
1406 // map the argument to the cleanup region
1407 Block &entry = cleanupRegion->front();
1408
1409 llvm::Instruction *potentialTerminator =
1410 builder.GetInsertBlock()->empty() ? nullptr
1411 : &builder.GetInsertBlock()->back();
1412 if (potentialTerminator && potentialTerminator->isTerminator())
1413 builder.SetInsertPoint(potentialTerminator);
1414 llvm::Value *privateVarValue =
1415 shouldLoadCleanupRegionArg
1416 ? builder.CreateLoad(
1417 moduleTranslation.convertType(entry.getArgument(0).getType()),
1418 privateVariables[i])
1419 : privateVariables[i];
1420
1421 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1422
1423 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1424 moduleTranslation)))
1425 return failure();
1426
1427 // clear block argument mapping in case it needs to be re-created with a
1428 // different source for another use of the same reduction decl
1429 moduleTranslation.forgetMapping(*cleanupRegion);
1430 }
1431 return success();
1432}
1433
1434// TODO: not used by ParallelOp
1435template <class OP>
1436static LogicalResult createReductionsAndCleanup(
1437 OP op, llvm::IRBuilderBase &builder,
1438 LLVM::ModuleTranslation &moduleTranslation,
1439 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1441 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1442 bool isNowait = false, bool isTeamsReduction = false) {
1443 // Process the reductions if required.
1444 if (op.getNumReductionVars() == 0)
1445 return success();
1446
1448 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1449 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1451
1452 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1453
1454 // Create the reduction generators. We need to own them here because
1455 // ReductionInfo only accepts references to the generators.
1456 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1457 owningReductionGens, owningAtomicReductionGens,
1458 owningReductionGenRefDataPtrGens,
1459 privateReductionVariables, reductionInfos, isByRef);
1460
1461 // The call to createReductions below expects the block to have a
1462 // terminator. Create an unreachable instruction to serve as terminator
1463 // and remove it later.
1464 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1465 builder.SetInsertPoint(tempTerminator);
1466 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1467 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1468 isByRef, isNowait, isTeamsReduction);
1469
1470 if (failed(handleError(contInsertPoint, *op)))
1471 return failure();
1472
1473 if (!contInsertPoint->getBlock())
1474 return op->emitOpError() << "failed to convert reductions";
1475
1476 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1477 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1478
1479 if (failed(handleError(afterIP, *op)))
1480 return failure();
1481
1482 tempTerminator->eraseFromParent();
1483 builder.restoreIP(*afterIP);
1484
1485 // after the construct, deallocate private reduction variables
1486 SmallVector<Region *> reductionRegions;
1487 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1488 [](omp::DeclareReductionOp reductionDecl) {
1489 return &reductionDecl.getCleanupRegion();
1490 });
1491 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1492 moduleTranslation, builder,
1493 "omp.reduction.cleanup");
1494 return success();
1495}
1496
1497static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1498 if (!attr)
1499 return {};
1500 return *attr;
1501}
1502
1503// TODO: not used by omp.parallel
1504template <typename OP>
1506 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1507 LLVM::ModuleTranslation &moduleTranslation,
1508 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1510 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1511 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1512 llvm::ArrayRef<bool> isByRef) {
1513 if (op.getNumReductionVars() == 0)
1514 return success();
1515
1516 SmallVector<DeferredStore> deferredStores;
1517
1518 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1519 allocaIP, reductionDecls,
1520 privateReductionVariables, reductionVariableMap,
1521 deferredStores, isByRef)))
1522 return failure();
1523
1524 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1525 allocaIP.getBlock(), reductionDecls,
1526 privateReductionVariables, reductionVariableMap,
1527 isByRef, deferredStores);
1528}
1529
1530/// Return the llvm::Value * corresponding to the `privateVar` that
1531/// is being privatized. It isn't always as simple as looking up
1532/// moduleTranslation with privateVar. For instance, in case of
1533/// an allocatable, the descriptor for the allocatable is privatized.
1534/// This descriptor is mapped using an MapInfoOp. So, this function
1535/// will return a pointer to the llvm::Value corresponding to the
1536/// block argument for the mapped descriptor.
1537static llvm::Value *
1538findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1539 LLVM::ModuleTranslation &moduleTranslation,
1540 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1541 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1542 return moduleTranslation.lookupValue(privateVar);
1543
1544 Value blockArg = (*mappedPrivateVars)[privateVar];
1545 Type privVarType = privateVar.getType();
1546 Type blockArgType = blockArg.getType();
1547 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1548 "A block argument corresponding to a mapped var should have "
1549 "!llvm.ptr type");
1550
1551 if (privVarType == blockArgType)
1552 return moduleTranslation.lookupValue(blockArg);
1553
1554 // This typically happens when the privatized type is lowered from
1555 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1556 // struct/pair is passed by value. But, mapped values are passed only as
1557 // pointers, so before we privatize, we must load the pointer.
1558 if (!isa<LLVM::LLVMPointerType>(privVarType))
1559 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1560 moduleTranslation.lookupValue(blockArg));
1561
1562 return moduleTranslation.lookupValue(privateVar);
1563}
1564
1565/// Initialize a single (first)private variable. You probably want to use
1566/// allocateAndInitPrivateVars instead of this.
1567/// This returns the private variable which has been initialized. This
1568/// variable should be mapped before constructing the body of the Op.
1570initPrivateVar(llvm::IRBuilderBase &builder,
1571 LLVM::ModuleTranslation &moduleTranslation,
1572 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1573 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1574 llvm::BasicBlock *privInitBlock,
1575 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1576 Region &initRegion = privDecl.getInitRegion();
1577 if (initRegion.empty())
1578 return llvmPrivateVar;
1579
1580 assert(nonPrivateVar);
1581 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1582 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1583
1584 // in-place convert the private initialization region
1586 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1587 moduleTranslation, &phis)))
1588 return llvm::createStringError(
1589 "failed to inline `init` region of `omp.private`");
1590
1591 assert(phis.size() == 1 && "expected one allocation to be yielded");
1592
1593 // clear init region block argument mapping in case it needs to be
1594 // re-created with a different source for another use of the same
1595 // reduction decl
1596 moduleTranslation.forgetMapping(initRegion);
1597
1598 // Prefer the value yielded from the init region to the allocated private
1599 // variable in case the region is operating on arguments by-value (e.g.
1600 // Fortran character boxes).
1601 return phis[0];
1602}
1603
/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
    omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
    llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Resolve the MLIR value to its LLVM counterpart (consulting
  // mappedPrivateVars first, if provided) and delegate to the main overload.
  return initPrivateVar(
      builder, moduleTranslation, privDecl,
      findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
                          mappedPrivateVars),
      blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
}
1616
/// Run the `init` region of every delayed-privatization variable in
/// \p privateVarsInfo and map each entry block argument to the resulting
/// (initialized) LLVM value.
static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation,
                PrivateVarsInfo &privateVarsInfo,
                llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // No private variables at all: nothing to initialize.
  if (privateVarsInfo.blockArgs.empty())
    return llvm::Error::success();

  // Keep all initialization code in a dedicated block.
  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
  setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);

  for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
        llvmPrivateVar, privInitBlock, mappedPrivateVars);

    if (!privVarOrErr)
      return privVarOrErr.takeError();

    // The init region may yield a different value (e.g. by-value boxes);
    // prefer it and map the block argument to it.
    llvmPrivateVar = privVarOrErr.get();
    moduleTranslation.mapValue(blockArg, llvmPrivateVar);

  }

  return llvm::Error::success();
}
1647
/// Allocate and initialize delayed private variables. Returns the basic block
/// which comes after all of these allocations. llvm::Value * for each of these
/// private variables are populated in llvmPrivateVars.
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    PrivateVarsInfo &privateVarsInfo,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
  // Split the alloca block so the allocas stay isolated from region code.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, allocaTerminator->getStableDebugLoc(),
          "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator since the alloca block was split above.
  allocaTerminator = allocaIP.getBlock()->getTerminator();
  builder.SetInsertPoint(allocaTerminator);
  // The new terminator is an unconditional branch created by the splitBB above.
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by splitBB");

  // NOTE(review): dataLayout appears unused below — consider removing.
  llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  unsigned int allocaAS =
      moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS = moduleTranslation.getLLVMModule()
                               ->getDataLayout()
                               .getProgramAddressSpace();

  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
                       privateVarsInfo.blockArgs)) {
    llvm::Type *llvmAllocType =
        moduleTranslation.convertType(privDecl.getType());
    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
        llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
    // On targets whose allocas live in a non-default address space, cast back
    // to the program address space so downstream code sees a uniform pointer.
    if (allocaAS != defaultAS)
      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                   builder.getPtrTy(defaultAS));

    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
  }

  return afterAllocas;
}
1698
/// This can't always be determined statically, but when we can, it is good to
/// avoid generating compiler-added barriers which will deadlock the program.
  // Walk up the ancestor chain looking for a single-threaded construct.
  for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    // Inside omp.single or omp.critical only one thread executes `op`.
    if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
      return true;

    // A parallel region between `op` and the single-threaded construct
    // reintroduces multiple threads, so stop the search there.
    // e.g.
    // omp.single {
    //   omp.parallel {
    //     op
    //   }
    // }
    if (mlir::isa<omp::ParallelOp>(parent))
      return false;
  }
  // Could not prove single-threaded execution.
  return false;
}
1718
/// Inline the `copy` region of every firstprivate privatizer, copying from the
/// mold (original) variable into the private copy. Optionally emits a barrier
/// afterwards (skipped when `op` provably runs single-threaded).
static LogicalResult copyFirstPrivateVars(
    mlir::Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    ArrayRef<llvm::Value *> llvmPrivateVars,
    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  // Keep the copies in a dedicated block.
  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);

  for (auto [decl, moldVar, llvmVar] :
       llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg
    moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");


    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  // Barrier so no thread reads a mold variable after another thread has
  // already moved on; skipped when only one thread can reach this point.
  if (insertBarrier && !opIsInSingleThread(op)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(res, *op)))
      return failure();
  }

  return success();
}
1778
1779static LogicalResult copyFirstPrivateVars(
1780 mlir::Operation *op, llvm::IRBuilderBase &builder,
1781 LLVM::ModuleTranslation &moduleTranslation,
1782 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1783 ArrayRef<llvm::Value *> llvmPrivateVars,
1784 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1785 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1786 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1787 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1788 // map copyRegion rhs arg
1789 llvm::Value *moldVar = findAssociatedValue(
1790 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1791 assert(moldVar);
1792 return moldVar;
1793 });
1794 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1795 llvmPrivateVars, privateDecls, insertBarrier,
1796 mappedPrivateVars);
1797}
1798
/// Inline the `dealloc` region of each privatizer to release resources held
/// by the corresponding private variable in \p llvmPrivateVars.
static LogicalResult
cleanupPrivateVars(llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation, Location loc,
                   SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
  // private variable deallocation
  SmallVector<Region *> privateCleanupRegions;
  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                  [](omp::PrivateClauseOp privatizer) {
                    return &privatizer.getDeallocRegion();
                  });

  if (failed(inlineOmpRegionCleanup(
          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
          "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
    return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                "`omp.private` op in");

  return success();
}
1819
/// Returns true if the construct contains omp.cancel or omp.cancellation_point
  // omp.cancel and omp.cancellation_point must be "closely nested" so they will
  // be visible and not inside of function calls. This is enforced by the
  // verifier.
  return op
      ->walk([](Operation *child) {
        // Stop at the first cancellation-related op; its presence is all we
        // need to know.
        if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}
1833
/// Converts an OpenMP sections construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  // Bail out early on clauses we cannot translate yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

      sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
      reductionDecls, privateReductionVariables, reductionVariableMap,
      isByRef)))
    return failure();


  // Build one body-generation callback per omp.section nested in the op.
  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  bool isCancellable = constructIsCancellable(sectionsOp);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, sectionsOp.getNowait());
}
1939
/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Bail out early on clauses we cannot translate yet.
  if (failed(checkImplementationStatus(*singleOp)))
    return failure();

  // Generate the single region's body at the point chosen by the builder.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  // Collect the LLVM value and helper function for each copyprivate variable.
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
          llvmCPFuncs);

  if (failed(handleError(afterIP, *singleOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
1982
/// Returns true when every (non-debug) use of the teams op's reduction block
/// arguments is nested inside one single omp.distribute op, i.e. the
/// reduction will be handled by the distribute construct instead of teams.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
  auto iface =
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
  // Check that all uses of the reduction block arg has the same distribute op
  // parent.
  Operation *distOp = nullptr;
  for (auto ra : iface.getReductionBlockArgs())
    for (auto &use : ra.getUses()) {
      auto *useOp = use.getOwner();
      // Ignore debug uses.
      if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
        debugUses.push_back(useOp);
        continue;
      }

      auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
      // Use is not inside a distribute op - return false
      if (!currentDistOp)
        return false;
      // Multiple distribute operations - return false
      Operation *currentOp = currentDistOp.getOperation();
      if (distOp && (distOp != currentOp))
        return false;

      distOp = currentOp;
    }

  // If we are going to use distribute reduction then remove any debug uses of
  // the reduction parameters in teamsOp. Otherwise they will be left without
  // any mapped value in moduleTranslation and will eventually error out.
  for (auto *use : debugUses)
    use->erase();
  return true;
}
2018
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Bail out early on clauses we cannot translate yet.
  if (failed(checkImplementationStatus(*op)))
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = op.getNumReductionVars();
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Only do teams reduction if there is no distribute op that captures the
  // reduction instead.
  bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
  if (doTeamsReduction) {
    isByRef = getIsByRef(op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

    collectReductionDecls(op, reductionDecls);

        op, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  // Generate the teams region's body at the point chosen by the builder.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
        moduleTranslation, allocaIP);
    builder.restoreIP(codegenIP);
    return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the optional clause operands; null means "clause absent".
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(threadLimitVar);

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);

  if (failed(handleError(afterIP, *op)))
    return failure();

  builder.restoreIP(*afterIP);
  if (doTeamsReduction) {
    // Process the reductions if required.
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
2098
/// Translate the `depend` clause operands into OpenMPIRBuilder DependData
/// entries, mapping each MLIR dependence kind to its runtime equivalent.
static void
buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
                LLVM::ModuleTranslation &moduleTranslation,
  if (dependVars.empty())
    return;
  // Invariant: dependKinds is present and parallel to dependVars whenever
  // dependVars is non-empty.
  for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
    llvm::omp::RTLDependenceKindTy type;
    switch (
        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
    case mlir::omp::ClauseTaskDepend::taskdependin:
      type = llvm::omp::RTLDependenceKindTy::DepIn;
      break;
    // The OpenMP runtime requires that the codegen for 'depend' clause for
    // 'out' dependency kind must be the same as codegen for 'depend' clause
    // with 'inout' dependency.
    case mlir::omp::ClauseTaskDepend::taskdependout:
    case mlir::omp::ClauseTaskDepend::taskdependinout:
      type = llvm::omp::RTLDependenceKindTy::DepInOut;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
      break;
    case mlir::omp::ClauseTaskDepend::taskdependinoutset:
      type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
      break;
    };
    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
    dds.emplace_back(dd);
  }
}
2131
/// Shared implementation of a callback which adds a terminator for the new block
/// created for the branch taken when an openmp construct is cancelled. The
/// terminator is saved in \p cancelTerminators. This callback is invoked only
/// if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct.
static void
                         llvm::IRBuilderBase &llvmBuilder,
                         llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
                         llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    // ip is currently in the block branched to if cancellation occurred.
    // We need to create a branch to terminate that block.
    llvmBuilder.restoreIP(ip);

    // We must still clean up the construct after cancelling it, so we need to
    // branch to the block that finalizes the taskgroup.
    // That block has not been created yet so use this block as a dummy for now
    // and fix this after creating the operation.
    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
    return llvm::Error::success();
  };
  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
  // created in case the body contains omp.cancel (which will then expect to be
  // able to find this cleanup callback).
  ompBuilder.pushFinalizationCB(
      {finiCB, cancelDirective, constructIsCancellable(op)});
}
2163
/// If we cancelled the construct, we should branch to the finalization block of
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
/// is immediately before the continuation block. Now this finalization has
/// been created we can fix the branch.
static void
                        llvm::OpenMPIRBuilder &ompBuilder,
                        const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
  ompBuilder.popFinalizationCB();
  // The cleanup block is the single predecessor of the continuation block.
  llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
  // Retarget every placeholder branch recorded by pushCancelFinalizationCB.
  for (llvm::BranchInst *cancelBranch : cancelTerminators) {
    assert(cancelBranch->getNumSuccessors() == 1 &&
           "cancel branch should have one target");
    cancelBranch->setSuccessor(0, constructFini);
  }
}
2180
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs into the context struct, parallel to privateDecls (null entries for
  /// decls not stored in the struct).
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Pointer to the heap-allocated context struct; null if none was needed.
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2238
2239void TaskContextStructManager::generateTaskContextStruct() {
2240 if (privateDecls.empty())
2241 return;
2242 privateVarTypes.reserve(privateDecls.size());
2243
2244 for (omp::PrivateClauseOp &privOp : privateDecls) {
2245 // Skip private variables which can safely be allocated and initialised
2246 // inside of the task
2247 if (!privOp.readsFromMold())
2248 continue;
2249 Type mlirType = privOp.getType();
2250 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2251 }
2252
2253 if (privateVarTypes.empty())
2254 return;
2255
2256 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2257 privateVarTypes);
2258
2259 llvm::DataLayout dataLayout =
2260 builder.GetInsertBlock()->getModule()->getDataLayout();
2261 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2262 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2263
2264 // Heap allocate the structure
2265 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2266 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2267 "omp.task.context_ptr");
2268}
2269
2270SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2271 llvm::Value *altStructPtr) const {
2272 SmallVector<llvm::Value *> ret;
2273
2274 // Create GEPs for each struct member
2275 ret.reserve(privateDecls.size());
2276 llvm::Value *zero = builder.getInt32(0);
2277 unsigned i = 0;
2278 for (auto privDecl : privateDecls) {
2279 if (!privDecl.readsFromMold()) {
2280 // Handle this inside of the task so we don't pass unnessecary vars in
2281 ret.push_back(nullptr);
2282 continue;
2283 }
2284 llvm::Value *iVal = builder.getInt32(i);
2285 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2286 ret.push_back(gep);
2287 i += 1;
2288 }
2289 return ret;
2290}
2291
2292void TaskContextStructManager::createGEPsToPrivateVars() {
2293 if (!structPtr)
2294 assert(privateVarTypes.empty());
2295 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2296 // with null values for skipped private decls
2297
2298 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2299}
2300
2301void TaskContextStructManager::freeStructPtr() {
2302 if (!structPtr)
2303 return;
2304
2305 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2306 // Ensure we don't put the call to free() after the terminator
2307 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2308 builder.CreateFree(structPtr);
2309}
2310
2311/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2312static LogicalResult
2313convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2314 LLVM::ModuleTranslation &moduleTranslation) {
2315 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2316 if (failed(checkImplementationStatus(*taskOp)))
2317 return failure();
2318
2319 PrivateVarsInfo privateVarsInfo(taskOp);
2320 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2321 privateVarsInfo.privatizers};
2322
2323 // Allocate and copy private variables before creating the task. This avoids
2324 // accessing invalid memory if (after this scope ends) the private variables
2325 // are initialized from host variables or if the variables are copied into
2326 // from host variables (firstprivate). The insertion point is just before
2327 // where the code for creating and scheduling the task will go. That puts this
2328 // code outside of the outlined task region, which is what we want because
2329 // this way the initialization and copy regions are executed immediately while
2330 // the host variable data are still live.
2331
2332 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2333 findAllocaInsertPoint(builder, moduleTranslation);
2334
2335 // Not using splitBB() because that requires the current block to have a
2336 // terminator.
2337 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2338 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2339 builder.getContext(), "omp.task.start",
2340 /*Parent=*/builder.GetInsertBlock()->getParent());
2341 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2342 builder.SetInsertPoint(branchToTaskStartBlock);
2343
2344 // Now do this again to make the initialization and copy blocks
2345 llvm::BasicBlock *copyBlock =
2346 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2347 llvm::BasicBlock *initBlock =
2348 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2349
2350 // Now the control flow graph should look like
2351 // starter_block:
2352 // <---- where we started when convertOmpTaskOp was called
2353 // br %omp.private.init
2354 // omp.private.init:
2355 // br %omp.private.copy
2356 // omp.private.copy:
2357 // br %omp.task.start
2358 // omp.task.start:
2359 // <---- where we want the insertion point to be when we call createTask()
2360
2361 // Save the alloca insertion point on ModuleTranslation stack for use in
2362 // nested regions.
2364 moduleTranslation, allocaIP);
2365
2366 // Allocate and initialize private variables
2367 builder.SetInsertPoint(initBlock->getTerminator());
2368
2369 // Create task variable structure
2370 taskStructMgr.generateTaskContextStruct();
2371 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2372 // of the body otherwise it will be the GEP not the struct which is fowarded
2373 // to the outlined function. GEPs forwarded in this way are passed in a
2374 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2375 // which may not be executed until after the current stack frame goes out of
2376 // scope.
2377 taskStructMgr.createGEPsToPrivateVars();
2378
2379 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2380 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2381 privateVarsInfo.blockArgs,
2382 taskStructMgr.getLLVMPrivateVarGEPs())) {
2383 // To be handled inside the task.
2384 if (!privDecl.readsFromMold())
2385 continue;
2386 assert(llvmPrivateVarAlloc &&
2387 "reads from mold so shouldn't have been skipped");
2388
2389 llvm::Expected<llvm::Value *> privateVarOrErr =
2390 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2391 blockArg, llvmPrivateVarAlloc, initBlock);
2392 if (!privateVarOrErr)
2393 return handleError(privateVarOrErr, *taskOp.getOperation());
2394
2396
2397 // TODO: this is a bit of a hack for Fortran character boxes.
2398 // Character boxes are passed by value into the init region and then the
2399 // initialized character box is yielded by value. Here we need to store the
2400 // yielded value into the private allocation, and load the private
2401 // allocation to match the type expected by region block arguments.
2402 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2403 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2404 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2405 // Load it so we have the value pointed to by the GEP
2406 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2407 llvmPrivateVarAlloc);
2408 }
2409 assert(llvmPrivateVarAlloc->getType() ==
2410 moduleTranslation.convertType(blockArg.getType()));
2411
2412 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2413 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2414 // stack allocated structure.
2415 }
2416
2417 // firstprivate copy region
2418 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2419 if (failed(copyFirstPrivateVars(
2420 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2421 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2422 taskOp.getPrivateNeedsBarrier())))
2423 return llvm::failure();
2424
2425 // Set up for call to createTask()
2426 builder.SetInsertPoint(taskStartBlock);
2427
2428 auto bodyCB = [&](InsertPointTy allocaIP,
2429 InsertPointTy codegenIP) -> llvm::Error {
2430 // Save the alloca insertion point on ModuleTranslation stack for use in
2431 // nested regions.
2433 moduleTranslation, allocaIP);
2434
2435 // translate the body of the task:
2436 builder.restoreIP(codegenIP);
2437
2438 llvm::BasicBlock *privInitBlock = nullptr;
2439 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2440 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2441 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2442 privateVarsInfo.mlirVars))) {
2443 auto [blockArg, privDecl, mlirPrivVar] = zip;
2444 // This is handled before the task executes
2445 if (privDecl.readsFromMold())
2446 continue;
2447
2448 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2449 llvm::Type *llvmAllocType =
2450 moduleTranslation.convertType(privDecl.getType());
2451 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2452 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2453 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2454
2455 llvm::Expected<llvm::Value *> privateVarOrError =
2456 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2457 blockArg, llvmPrivateVar, privInitBlock);
2458 if (!privateVarOrError)
2459 return privateVarOrError.takeError();
2460 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2461 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2462 }
2463
2464 taskStructMgr.createGEPsToPrivateVars();
2465 for (auto [i, llvmPrivVar] :
2466 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2467 if (!llvmPrivVar) {
2468 assert(privateVarsInfo.llvmVars[i] &&
2469 "This is added in the loop above");
2470 continue;
2471 }
2472 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2473 }
2474
2475 // Find and map the addresses of each variable within the task context
2476 // structure
2477 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2478 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2479 privateVarsInfo.privatizers)) {
2480 // This was handled above.
2481 if (!privateDecl.readsFromMold())
2482 continue;
2483 // Fix broken pass-by-value case for Fortran character boxes
2484 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2485 llvmPrivateVar = builder.CreateLoad(
2486 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2487 }
2488 assert(llvmPrivateVar->getType() ==
2489 moduleTranslation.convertType(blockArg.getType()));
2490 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2491 }
2492
2493 auto continuationBlockOrError = convertOmpOpRegions(
2494 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2495 if (failed(handleError(continuationBlockOrError, *taskOp)))
2496 return llvm::make_error<PreviouslyReportedError>();
2497
2498 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2499
2500 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2501 privateVarsInfo.llvmVars,
2502 privateVarsInfo.privatizers)))
2503 return llvm::make_error<PreviouslyReportedError>();
2504
2505 // Free heap allocated task context structure at the end of the task.
2506 taskStructMgr.freeStructPtr();
2507
2508 return llvm::Error::success();
2509 };
2510
2511 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2512 SmallVector<llvm::BranchInst *> cancelTerminators;
2513 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2514 // which is canceled. This is handled here because it is the task's cleanup
2515 // block which should be branched to.
2516 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2517 llvm::omp::Directive::OMPD_taskgroup);
2518
2520 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2521 moduleTranslation, dds);
2522
2523 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2524 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2525 moduleTranslation.getOpenMPBuilder()->createTask(
2526 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2527 moduleTranslation.lookupValue(taskOp.getFinal()),
2528 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2529 taskOp.getMergeable(),
2530 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2531 moduleTranslation.lookupValue(taskOp.getPriority()));
2532
2533 if (failed(handleError(afterIP, *taskOp)))
2534 return failure();
2535
2536 // Set the correct branch target for task cancellation
2537 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2538
2539 builder.restoreIP(*afterIP);
2540 return success();
2541}
2542
/// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
///
/// Private variables whose privatizer reads from the mold are initialized up
/// front (before the taskloop runtime call) into a task context structure
/// managed by `taskStructMgr`; the remaining privates are allocated and
/// initialized inside the task body callback. When the runtime duplicates the
/// task, `taskDupCB` duplicates the context structure as well.
static LogicalResult
convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto taskloopOp = cast<omp::TaskloopOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // It stores the pointer of allocated firstprivate copies,
  // which can be used later for freeing the allocated space.
  SmallVector<llvm::Value *> llvmFirstPrivateVars;
  PrivateVarsInfo privateVarsInfo(taskloopOp);
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
                                         privateVarsInfo.privatizers};

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Split the current block into dedicated regions: private
  // allocation/initialization, the firstprivate copy region, and the block
  // from which the taskloop runtime call will be emitted.
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
      builder.getContext(), "omp.taskloop.start",
      /*Parent=*/builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskloopStartBlock =
      builder.CreateBr(taskloopStartBlock);
  builder.SetInsertPoint(branchToTaskloopStartBlock);

  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.init");

  // NOTE(review): the first line of this statement (likely a
  // ModuleTranslation SaveStack frame declaration that saves the alloca
  // insertion point) appears to have been dropped by the documentation
  // extractor — confirm against the upstream sources.
      moduleTranslation, allocaIP);

  // Allocate and initialize private variables
  builder.SetInsertPoint(initBlock->getTerminator());

  // TODO: don't allocate if the loop has zero iterations.
  taskStructMgr.generateTaskContextStruct();
  taskStructMgr.createGEPsToPrivateVars();

  llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());

  // Initialize (outside of the task) the privates that read from their mold.
  for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
           privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
    auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
    // To be handled inside the taskloop.
    if (!privDecl.readsFromMold())
      continue;
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

    llvm::Expected<llvm::Value *> privateVarOrErr =
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVarAlloc, initBlock);
    if (!privateVarOrErr)
      return handleError(privateVarOrErr, *taskloopOp.getOperation());

    llvmFirstPrivateVars[i] = privateVarOrErr.get();

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

    // If the init region yielded a value distinct from the private allocation
    // (e.g. Fortran character boxes passed by value), store it into the
    // allocation and reload so the value matches the type expected by region
    // block arguments.
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
      builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
      // Load it so we have the value pointed to by the GEP
      llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
                                               llvmPrivateVarAlloc);
    }
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));
  }

  // firstprivate copy region
  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
  if (failed(copyFirstPrivateVars(
          taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
          taskloopOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  // Set up insertion point for call to createTaskloop()
  builder.SetInsertPoint(taskloopStartBlock);

  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    // NOTE(review): the first line of this statement appears to have been
    // dropped by the documentation extractor — confirm against upstream.
        moduleTranslation, allocaIP);

    // translate the body of the taskloop:
    builder.restoreIP(codegenIP);

    llvm::BasicBlock *privInitBlock = nullptr;
    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
    // Allocate and initialize (inside the task) the privates that do NOT read
    // from their mold; those were not placed in the context structure.
    for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
             privateVarsInfo.mlirVars))) {
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      // This is handled before the task executes
      if (privDecl.readsFromMold())
        continue;

      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");

      llvm::Expected<llvm::Value *> privateVarOrError =
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(blockArg, privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    }

    // Merge in the GEPs to the context-structure-resident privates; entries
    // that are null were allocated locally in the loop above.
    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
      if (!llvmPrivVar) {
        assert(privateVarsInfo.llvmVars[i] &&
               "This is added in the loop above");
        continue;
      }
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    }

    // Find and map the addresses of each variable within the taskloop context
    // structure
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
                         privateVarsInfo.privatizers)) {
      // This was handled above.
      if (!privateDecl.readsFromMold())
        continue;
      // Fix broken pass-by-value case for Fortran character boxes
      if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
      }
      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(blockArg, llvmPrivateVar);
    }

    auto continuationBlockOrError =
        convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
                            builder, moduleTranslation);

    if (failed(handleError(continuationBlockOrError, opInst)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    // This is freeing the private variables as mapped inside of the task: these
    // will be per-task private copies possibly after task duplication. This is
    // handled transparently by how these are passed to the structure passed
    // into the outlined function. When the task is duplicated, that structure
    // is duplicated too.
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  taskloopOp.getLoc(), privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();
    // Similarly, the task context structure freed inside the task is the
    // per-task copy after task duplication.
    taskStructMgr.freeStructPtr();

    return llvm::Error::success();
  };

  // Taskloop divides into an appropriate number of tasks by repeatedly
  // duplicating the original task. Each time this is done, the task context
  // structure must be duplicated too.
  auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
                       llvm::Value *destPtr, llvm::Value *srcPtr)
  // NOTE(review): the trailing return type and opening brace of this lambda
  // appear to have been dropped by the documentation extractor — confirm
  // against upstream.
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.restoreIP(codegenIP);

    llvm::Type *ptrTy =
        builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
    llvm::Value *src =
        builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");

    TaskContextStructManager &srcStructMgr = taskStructMgr;
    TaskContextStructManager destStructMgr(builder, moduleTranslation,
                                           privateVarsInfo.privatizers);
    destStructMgr.generateTaskContextStruct();
    llvm::Value *dest = destStructMgr.getStructPtr();
    dest->setName("omp.taskloop.context.dest");
    builder.CreateStore(dest, destPtr);

    // NOTE(review): the declarations of `srcGEPs` and `destGEPs` (used below)
    // appear to have been dropped by the documentation extractor — confirm
    // against upstream.
    srcStructMgr.createGEPsToPrivateVars(src);
    destStructMgr.createGEPsToPrivateVars(dest);

    // Inline init regions.
    for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
         llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
                         privateVarsInfo.blockArgs, destGEPs)) {
      // To be handled inside task body.
      if (!privDecl.readsFromMold())
        continue;
      assert(llvmPrivateVarAlloc &&
             "reads from mold so shouldn't have been skipped");

      llvm::Expected<llvm::Value *> privateVarOrErr =
          initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
                         llvmPrivateVarAlloc, builder.GetInsertBlock());
      if (!privateVarOrErr)
        return privateVarOrErr.takeError();

      // TODO: this is a bit of a hack for Fortran character boxes.
      // Character boxes are passed by value into the init region and then the
      // initialized character box is yielded by value. Here we need to store
      // the yielded value into the private allocation, and load the private
      // allocation to match the type expected by region block arguments.
      if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
          !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
        builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
        // Load it so we have the value pointed to by the GEP
        llvmPrivateVarAlloc = builder.CreateLoad(
            privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
      }
      assert(llvmPrivateVarAlloc->getType() ==
             moduleTranslation.convertType(blockArg.getType()));

      // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
      // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
      // through a stack allocated structure.
    }

    if (failed(copyFirstPrivateVars(
            &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
            privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    return builder.saveIP();
  };

  auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());

  // Thunk handing the current canonical loop info to createTaskloop.
  auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
    llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
    return loopInfo;
  };

  // Lower the if/grainsize/num_tasks clauses. `sched` selects how the
  // iteration space is divided: 0 = default, 1 = grainsize, 2 = num_tasks.
  llvm::Value *ifCond = nullptr;
  llvm::Value *grainsize = nullptr;
  int sched = 0; // default
  mlir::Value grainsizeVal = taskloopOp.getGrainsize();
  mlir::Value numTasksVal = taskloopOp.getNumTasks();
  if (Value ifVar = taskloopOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  if (grainsizeVal) {
    grainsize = moduleTranslation.lookupValue(grainsizeVal);
    sched = 1; // grainsize
  } else if (numTasksVal) {
    grainsize = moduleTranslation.lookupValue(numTasksVal);
    sched = 2; // num_tasks
  }

  // Only pass a duplication callback when there is a context structure that
  // needs duplicating.
  llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
  if (taskStructMgr.getStructPtr())
    taskDupOrNull = taskDupCB;

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  SmallVector<llvm::BranchInst *> cancelTerminators;
  // The directive to match here is OMPD_taskgroup because it is the
  // taskgroup which is canceled. This is handled here because it is the
  // task's cleanup block which should be branched to. It doesn't depend upon
  // nogroup because even in that case the taskloop might still be inside an
  // explicit taskgroup.
  pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
                           llvm::omp::Directive::OMPD_taskgroup);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTaskloop(
          ompLoc, allocaIP, bodyCB, loopInfo,
          moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[0]),
          moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[0]),
          moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]),
          taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
          sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
          taskloopOp.getMergeable(),
          moduleTranslation.lookupValue(taskloopOp.getPriority()),
          taskDupOrNull, taskStructMgr.getStructPtr());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Set the correct branch target for task cancellation.
  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());

  builder.restoreIP(*afterIP);
  return success();
}
2850
2851/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2852static LogicalResult
2853convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2854 LLVM::ModuleTranslation &moduleTranslation) {
2855 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2856 if (failed(checkImplementationStatus(*tgOp)))
2857 return failure();
2858
2859 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2860 builder.restoreIP(codegenIP);
2861 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2862 builder, moduleTranslation)
2863 .takeError();
2864 };
2865
2866 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2867 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2868 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2869 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2870 bodyCB);
2871
2872 if (failed(handleError(afterIP, *tgOp)))
2873 return failure();
2874
2875 builder.restoreIP(*afterIP);
2876 return success();
2877}
2878
2879static LogicalResult
2880convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2881 LLVM::ModuleTranslation &moduleTranslation) {
2882 if (failed(checkImplementationStatus(*twOp)))
2883 return failure();
2884
2885 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2886 return success();
2887}
2888
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto wsloopOp = cast<omp::WsloopOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
    // The chunk size is converted to the induction variable's type.
    chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
  }

  // When nested directly inside omp.distribute (i.e. 'distribute parallel
  // do'), pick up the dist_schedule clause from the parent op.
  omp::DistributeOp distributeOp = nullptr;
  llvm::Value *distScheduleChunk = nullptr;
  bool hasDistSchedule = false;
  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
    hasDistSchedule = distributeOp.getDistScheduleStatic();
    if (distributeOp.getDistScheduleChunkSize()) {
      llvm::Value *chunkVar = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
    }
  }

  PrivateVarsInfo privateVarsInfo(wsloopOp);

  // NOTE(review): the declaration of `reductionDecls` (used below) appears to
  // have been dropped by the documentation extractor — confirm against
  // upstream.
  collectReductionDecls(wsloopOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // NOTE(review): the first line of this statement (likely the declaration of
  // `afterAllocas` initialized from a private-variable allocation helper)
  // appears to have been dropped by the documentation extractor — confirm
  // against upstream.
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  SmallVector<DeferredStore> deferredStores;

  if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  if (failed(copyFirstPrivateVars(
          wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
          privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
          wsloopOp.getPrivateNeedsBarrier())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  // The only legal way for the direct parent to be omp.distribute is that this
  // represents 'distribute parallel do'. Otherwise, this is a regular
  // worksharing loop.
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

  SmallVector<llvm::BranchInst *> cancelTerminators;
  pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
                           llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!wsloopOp.getLinearVars().empty()) {
    auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);

    for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
      linearClauseProcessor.createLinearVar(
          builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
          idx);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  // NOTE(review): the first line of this statement (likely the declaration of
  // `regionBlock` initialized from convertOmpOpRegions) appears to have been
  // dropped by the documentation extractor — confirm against upstream.
      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);

  // Emit Initialization and Update IR for linear variables
  if (!wsloopOp.getLinearVars().empty()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    builder.restoreIP(*afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
    linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
  }

  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Check if we can generate no-loop kernel
  bool noLoopMode = false;
  omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
  if (targetOp) {
    Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
    // We need this check because, without it, noLoopMode would be set to true
    // for every omp.wsloop nested inside a no-loop SPMD target region, even if
    // that loop is not the top-level SPMD one.
    if (loopOp == targetCapturedOp) {
      omp::TargetRegionFlags kernelFlags =
          targetOp.getKernelExecFlags(targetCapturedOp);
      // no-loop requires SPMD without the generic flag.
      if (omp::bitEnumContainsAll(kernelFlags,
                                  omp::TargetRegionFlags::spmd |
                                      omp::TargetRegionFlags::no_loop) &&
          !omp::bitEnumContainsAny(kernelFlags,
                                   omp::TargetRegionFlags::generic))
        noLoopMode = true;
    }
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
          convertToScheduleKind(schedule), chunk, isSimd,
          scheduleMod == omp::ScheduleModifier::monotonic,
          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);

  if (failed(handleError(wsloopIP, opInst)))
    return failure();

  // Emit finalization and in-place rewrites for linear vars.
  if (!wsloopOp.getLinearVars().empty()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                loopInfo->getLastIter());
    if (failed(handleError(afterBarrierIP, *loopOp)))
      return failure();
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                           index);
    builder.restoreIP(oldIP);
  }

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());

  // Process the reductions if required.
  if (failed(createReductionsAndCleanup(
          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
          privateReductionVariables, isByRef, wsloopOp.getNowait(),
          /*isTeamsReduction=*/false)))
    return failure();

  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3098
3099/// Converts the OpenMP parallel operation to LLVM IR.
3100static LogicalResult
3101convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3102 LLVM::ModuleTranslation &moduleTranslation) {
3103 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3104 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3105 assert(isByRef.size() == opInst.getNumReductionVars());
3106 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3107 bool isCancellable = constructIsCancellable(opInst);
3108
3109 if (failed(checkImplementationStatus(*opInst)))
3110 return failure();
3111
3112 PrivateVarsInfo privateVarsInfo(opInst);
3113
3114 // Collect reduction declarations
3116 collectReductionDecls(opInst, reductionDecls);
3117 SmallVector<llvm::Value *> privateReductionVariables(
3118 opInst.getNumReductionVars());
3119 SmallVector<DeferredStore> deferredStores;
3120
3121 auto bodyGenCB = [&](InsertPointTy allocaIP,
3122 InsertPointTy codeGenIP) -> llvm::Error {
3124 builder, moduleTranslation, privateVarsInfo, allocaIP);
3125 if (handleError(afterAllocas, *opInst).failed())
3126 return llvm::make_error<PreviouslyReportedError>();
3127
3128 // Allocate reduction vars
3129 DenseMap<Value, llvm::Value *> reductionVariableMap;
3130
3131 MutableArrayRef<BlockArgument> reductionArgs =
3132 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3133
3134 allocaIP =
3135 InsertPointTy(allocaIP.getBlock(),
3136 allocaIP.getBlock()->getTerminator()->getIterator());
3137
3138 if (failed(allocReductionVars(
3139 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3140 reductionDecls, privateReductionVariables, reductionVariableMap,
3141 deferredStores, isByRef)))
3142 return llvm::make_error<PreviouslyReportedError>();
3143
3144 assert(afterAllocas.get()->getSinglePredecessor());
3145 builder.restoreIP(codeGenIP);
3146
3147 if (handleError(
3148 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3149 *opInst)
3150 .failed())
3151 return llvm::make_error<PreviouslyReportedError>();
3152
3153 if (failed(copyFirstPrivateVars(
3154 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3155 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3156 opInst.getPrivateNeedsBarrier())))
3157 return llvm::make_error<PreviouslyReportedError>();
3158
3159 if (failed(
3160 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3161 afterAllocas.get()->getSinglePredecessor(),
3162 reductionDecls, privateReductionVariables,
3163 reductionVariableMap, isByRef, deferredStores)))
3164 return llvm::make_error<PreviouslyReportedError>();
3165
3166 // Save the alloca insertion point on ModuleTranslation stack for use in
3167 // nested regions.
3169 moduleTranslation, allocaIP);
3170
3171 // ParallelOp has only one region associated with it.
3173 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3174 if (!regionBlock)
3175 return regionBlock.takeError();
3176
3177 // Process the reductions if required.
3178 if (opInst.getNumReductionVars() > 0) {
3179 // Collect reduction info
3181 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3183 owningReductionGenRefDataPtrGens;
3185 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3186 owningReductionGens, owningAtomicReductionGens,
3187 owningReductionGenRefDataPtrGens,
3188 privateReductionVariables, reductionInfos, isByRef);
3189
3190 // Move to region cont block
3191 builder.SetInsertPoint((*regionBlock)->getTerminator());
3192
3193 // Generate reductions from info
3194 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3195 builder.SetInsertPoint(tempTerminator);
3196
3197 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3198 ompBuilder->createReductions(
3199 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3200 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3201 if (!contInsertPoint)
3202 return contInsertPoint.takeError();
3203
3204 if (!contInsertPoint->getBlock())
3205 return llvm::make_error<PreviouslyReportedError>();
3206
3207 tempTerminator->eraseFromParent();
3208 builder.restoreIP(*contInsertPoint);
3209 }
3210
3211 return llvm::Error::success();
3212 };
3213
3214 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3215 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3216 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3217 // bodyGenCB.
3218 replVal = &val;
3219 return codeGenIP;
3220 };
3221
3222 // TODO: Perform finalization actions for variables. This has to be
3223 // called for variables which have destructors/finalizers.
3224 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3225 InsertPointTy oldIP = builder.saveIP();
3226 builder.restoreIP(codeGenIP);
3227
3228 // if the reduction has a cleanup region, inline it here to finalize the
3229 // reduction variables
3230 SmallVector<Region *> reductionCleanupRegions;
3231 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3232 [](omp::DeclareReductionOp reductionDecl) {
3233 return &reductionDecl.getCleanupRegion();
3234 });
3235 if (failed(inlineOmpRegionCleanup(
3236 reductionCleanupRegions, privateReductionVariables,
3237 moduleTranslation, builder, "omp.reduction.cleanup")))
3238 return llvm::createStringError(
3239 "failed to inline `cleanup` region of `omp.declare_reduction`");
3240
3241 if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
3242 privateVarsInfo.llvmVars,
3243 privateVarsInfo.privatizers)))
3244 return llvm::make_error<PreviouslyReportedError>();
3245
3246 // If we could be performing cancellation, add the cancellation barrier on
3247 // the way out of the outlined region.
3248 if (isCancellable) {
3249 auto IPOrErr = ompBuilder->createBarrier(
3250 llvm::OpenMPIRBuilder::LocationDescription(builder),
3251 llvm::omp::Directive::OMPD_unknown,
3252 /* ForceSimpleCall */ false,
3253 /* CheckCancelFlag */ false);
3254 if (!IPOrErr)
3255 return IPOrErr.takeError();
3256 }
3257
3258 builder.restoreIP(oldIP);
3259 return llvm::Error::success();
3260 };
3261
3262 llvm::Value *ifCond = nullptr;
3263 if (auto ifVar = opInst.getIfExpr())
3264 ifCond = moduleTranslation.lookupValue(ifVar);
3265 llvm::Value *numThreads = nullptr;
3266 if (auto numThreadsVar = opInst.getNumThreads())
3267 numThreads = moduleTranslation.lookupValue(numThreadsVar);
3268 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3269 if (auto bind = opInst.getProcBindKind())
3270 pbKind = getProcBindKind(*bind);
3271
3272 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3273 findAllocaInsertPoint(builder, moduleTranslation);
3274 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3275
3276 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3277 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3278 ifCond, numThreads, pbKind, isCancellable);
3279
3280 if (failed(handleError(afterIP, *opInst)))
3281 return failure();
3282
3283 builder.restoreIP(*afterIP);
3284 return success();
3285}
3286
3287/// Convert Order attribute to llvm::omp::OrderKind.
3288static llvm::omp::OrderKind
3289convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3290 if (!o)
3291 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3292 switch (*o) {
3293 case omp::ClauseOrderKind::Concurrent:
3294 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3295 }
3296 llvm_unreachable("Unknown ClauseOrderKind kind");
3297}
3298
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto simdOp = cast<omp::SimdOp>(opInst);

  // Bail out early if the op carries clauses this translation does not
  // support yet.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(simdOp);

  // Gather reduction state: the entry-block arguments standing in for
  // reduction variables, one privatized LLVM value slot per reduction, and
  // the by-ref flags describing how each reduction operand is passed.
  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
  DenseMap<Value, llvm::Value *> reductionVariableMap;
  SmallVector<llvm::Value *> privateReductionVariables(
      simdOp.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;
  collectReductionDecls(simdOp, reductionDecls);
  llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;

  if (!simdOp.getLinearVars().empty()) {
    auto linearVarTypes = simdOp.getLinearVarTypes().value();
    for (mlir::Attribute linearVarType : linearVarTypes)
      linearClauseProcessor.registerType(moduleTranslation, linearVarType);
    for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
      bool isImplicit = false;
      for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
               privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
        // If the linear variable is implicit, reuse the already
        // existing llvm::Value
        if (linearVar == mlirPrivVar) {
          isImplicit = true;
          linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                                llvmPrivateVar, idx);
          break;
        }
      }

      // Explicit linear variable: look its LLVM value up from the mapping.
      if (!isImplicit)
        linearClauseProcessor.createLinearVar(
            builder, moduleTranslation,
            moduleTranslation.lookupValue(linearVar), idx);
    }
    for (mlir::Value linearStep : simdOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  if (failed(allocReductionVars(simdOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  opInst)
          .failed())
    return failure();

  // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
  // SIMD.

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(initReductionVars(simdOp, reductionArgs, builder,
                               moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // Translate the optional simdlen/safelen clauses into i64 constants for
  // OpenMPIRBuilder; nullptr means the clause was absent.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());

  // Build the aligned-variable map for the `aligned` clause. Loads of the
  // aligned pointers are emitted in the block current at this point
  // (sourceBlock), before the loop region is converted.
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();

    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
    alignment = builder.getInt64(intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");

    // Check if the alignment value is not a power of 2. If so, skip emitting
    // alignment.
    if (!intAttr.getValue().isPowerOf2())
      continue;

    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(ty, llvmVal);
    builder.restoreIP(curInsert);
    alignedVars[llvmVal] = alignment;
  }

      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);

  if (failed(handleError(regionBlock, opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
  // Emit Initialization for linear variables
  if (simdOp.getLinearVars().size()) {
    linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                        loopInfo->getPreheader());

    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
                                          loopInfo->getIndVar());
  }
  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

  // Attach simd metadata/transformations to the already-built canonical loop.
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  linearClauseProcessor.emitStoresForLinearVar(builder);
  for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
    linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
                                         index);

  // We now need to reduce the per-simd-lane reduction variable into the
  // original variable. This works a bit differently to other reductions (e.g.
  // wsloop) because we don't need to call into the OpenMP runtime to handle
  // threads: everything happened in this one thread.
  for (auto [i, tuple] : llvm::enumerate(
           llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
                     privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

    // We have one less load for by-ref case because that load is now inside of
    // the reduction region.
    llvm::Value *redValue = originalVariable;
    if (!byRef)
      redValue =
          builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        reductionType, privateReductionVar, "red.private.value." + Twine(i));
    llvm::Value *reduced;

    // Inline the declared reduction combiner on (original, private) values.
    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    if (failed(handleError(res, opInst)))
      return failure();
    builder.restoreIP(res.get());

    // For by-ref case, the store is inside of the reduction region.
    if (!byRef)
      builder.CreateStore(reduced, originalVariable);
  }

  // After the construct, deallocate private reduction variables.
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                    moduleTranslation, builder,
                                    "omp.reduction.cleanup")))
    return failure();

  // Finally run the privatizers' dealloc regions for private clause vars.
  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            privateVarsInfo.llvmVars,
                            privateVarsInfo.privatizers);
}
3494
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop's body callback actually converts the MLIR
    // region; outer invocations just record their insertion points.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are created at the body insertion point of the previous
      // loop, while their trip counts are computed in the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Insertion point right after the whole (outermost) loop.
  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Do tiling.
  if (const auto &tiles = loopOp.getTileSizes()) {
    llvm::Type *ivType = loopInfos.front()->getIndVarType();

    // Materialize each tile size as a constant of the induction-var type.
    for (auto tile : tiles.value()) {
      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
      tileSizes.push_back(tileVal);
    }

    std::vector<llvm::CanonicalLoopInfo *> newLoops =
        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);

    // Update afterIP to get the correct insertion point after
    // tiling.
    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
    afterIP = {afterAfterBB, afterAfterBB->begin()};

    // Update the loop infos.
    loopInfos.clear();
    for (const auto &newLoop : newLoops)
      loopInfos.push_back(newLoop);
  } // Tiling done.

  // Do collapse.
  const auto &numCollapse = loopOp.getCollapseNumLoops();
      loopInfos.begin(), loopInfos.begin() + (numCollapse));

  auto newTopLoopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});

  assert(newTopLoopInfo && "New top loop information is missing");
  // Publish the collapsed loop to the innermost loop-info stack frame so the
  // enclosing worksharing construct can find it.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = newTopLoopInfo;
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);
  return success();
}
3620
/// Convert an omp.canonical_loop to LLVM-IR
static LogicalResult
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  // The trip count must already have been translated to LLVM IR.
  llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);

      ompBuilder->createCanonicalLoop(
          loopLoc,
          [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
            // Register the mapping of MLIR induction variable to LLVM-IR
            // induction variable
            moduleTranslation.mapValue(loopIV, llvmIV);

            builder.restoreIP(ip);
                convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
                                    moduleTranslation);

            // Propagate any error raised while converting the loop body.
            return bodyGenStatus.takeError();
          },
          llvmTC, "omp.loop");
  if (!llvmOrError)
    return op.emitError(llvm::toString(llvmOrError.takeError()));

  // Resume emitting IR at the point right after the generated loop.
  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(afterIP);

  // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
  if (Value cli = op.getCli())
    moduleTranslation.mapOmpLoop(cli, llvmCLI);

  return success();
}
3662
3663/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3664/// OpenMPIRBuilder.
3665static LogicalResult
3666applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3667 LLVM::ModuleTranslation &moduleTranslation) {
3668 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3669
3670 Value applyee = op.getApplyee();
3671 assert(applyee && "Loop to apply unrolling on required");
3672
3673 llvm::CanonicalLoopInfo *consBuilderCLI =
3674 moduleTranslation.lookupOMPLoop(applyee);
3675 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3676 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3677
3678 moduleTranslation.invalidateOmpLoop(applyee);
3679 return success();
3680}
3681
3682/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3683/// OpenMPIRBuilder.
3684static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3685 LLVM::ModuleTranslation &moduleTranslation) {
3686 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3687 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3688
3690 SmallVector<llvm::Value *> translatedSizes;
3691
3692 for (Value size : op.getSizes()) {
3693 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3694 assert(translatedSize &&
3695 "sizes clause arguments must already be translated");
3696 translatedSizes.push_back(translatedSize);
3697 }
3698
3699 for (Value applyee : op.getApplyees()) {
3700 llvm::CanonicalLoopInfo *consBuilderCLI =
3701 moduleTranslation.lookupOMPLoop(applyee);
3702 assert(applyee && "Canonical loop must already been translated");
3703 translatedLoops.push_back(consBuilderCLI);
3704 }
3705
3706 auto generatedLoops =
3707 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3708 if (!op.getGeneratees().empty()) {
3709 for (auto [mlirLoop, genLoop] :
3710 zip_equal(op.getGeneratees(), generatedLoops))
3711 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3712 }
3713
3714 // CLIs can only be consumed once
3715 for (Value applyee : op.getApplyees())
3716 moduleTranslation.invalidateOmpLoop(applyee);
3717
3718 return success();
3719}
3720
3721/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3722static llvm::AtomicOrdering
3723convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3724 if (!ao)
3725 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3726
3727 switch (*ao) {
3728 case omp::ClauseMemoryOrderKind::Seq_cst:
3729 return llvm::AtomicOrdering::SequentiallyConsistent;
3730 case omp::ClauseMemoryOrderKind::Acq_rel:
3731 return llvm::AtomicOrdering::AcquireRelease;
3732 case omp::ClauseMemoryOrderKind::Acquire:
3733 return llvm::AtomicOrdering::Acquire;
3734 case omp::ClauseMemoryOrderKind::Release:
3735 return llvm::AtomicOrdering::Release;
3736 case omp::ClauseMemoryOrderKind::Relaxed:
3737 return llvm::AtomicOrdering::Monotonic;
3738 }
3739 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3740}
3741
3742/// Convert omp.atomic.read operation to LLVM IR.
3743static LogicalResult
3744convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3745 LLVM::ModuleTranslation &moduleTranslation) {
3746 auto readOp = cast<omp::AtomicReadOp>(opInst);
3747 if (failed(checkImplementationStatus(opInst)))
3748 return failure();
3749
3750 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3751 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3752 findAllocaInsertPoint(builder, moduleTranslation);
3753
3754 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3755
3756 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3757 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3758 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3759
3760 llvm::Type *elementType =
3761 moduleTranslation.convertType(readOp.getElementType());
3762
3763 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3764 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3765 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3766 return success();
3767}
3768
3769/// Converts an omp.atomic.write operation to LLVM IR.
3770static LogicalResult
3771convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3772 LLVM::ModuleTranslation &moduleTranslation) {
3773 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3774 if (failed(checkImplementationStatus(opInst)))
3775 return failure();
3776
3777 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3778 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3779 findAllocaInsertPoint(builder, moduleTranslation);
3780
3781 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3782 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3783 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3784 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3785 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3786 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3787 /*isVolatile=*/false};
3788 builder.restoreIP(
3789 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3790 return success();
3791}
3792
/// Converts an LLVM dialect binary operation to the corresponding enum value
/// for `atomicrmw` supported binary operation.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
      // Dispatch on the concrete LLVM-dialect op type: each supported
      // arithmetic/logical op maps directly to an `atomicrmw` opcode. Any
      // other op yields BAD_BINOP, which callers treat as "no single
      // atomicrmw exists" and fall back to a cmpxchg loop.
      .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
      .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
      .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
      .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
      .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
      .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
      .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
      .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
      .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
      .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
}
3808
3809static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3810 bool &isIgnoreDenormalMode,
3811 bool &isFineGrainedMemory,
3812 bool &isRemoteMemory) {
3813 isIgnoreDenormalMode = false;
3814 isFineGrainedMemory = false;
3815 isRemoteMemory = false;
3816 if (atomicUpdateOp &&
3817 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3818 mlir::omp::AtomicControlAttr atomicControlAttr =
3819 atomicUpdateOp.getAtomicControlAttr();
3820 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3821 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3822 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3823 }
3824}
3825
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    // Determine whether the region argument (the current value of x) is the
    // left operand; the other operand is the update expression.
    binop = convertBinOpToAtomic(innerOp);
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  // Callback used by OpenMPIRBuilder for the cmpxchg-loop path: inlines the
  // MLIR update region with `atomicx` substituted for the region argument
  // and returns the value yielded by the region's omp.yield terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Optional target-specific atomic-control flags (all false by default).
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
3910
/// Converts an omp.atomic.capture operation (an atomic read combined with an
/// atomic update or write on the same location) using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // Write form: the captured value is the old one (postfix semantics) and
    // the expression is simply the value being stored.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Update form: postfix iff the update is the second op in the capture
    // region (i.e. the read happens before the update).
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one update op: no single atomicrmw fits, use cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  // x is the atomically updated location, v receives the captured value.
  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback inlining the update region (or returning the written value for
  // the write form); yields the new value of x.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Optional target-specific atomic-control flags (all false by default).
  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4011
4012static llvm::omp::Directive convertCancellationConstructType(
4013 omp::ClauseCancellationConstructType directive) {
4014 switch (directive) {
4015 case omp::ClauseCancellationConstructType::Loop:
4016 return llvm::omp::Directive::OMPD_for;
4017 case omp::ClauseCancellationConstructType::Parallel:
4018 return llvm::omp::Directive::OMPD_parallel;
4019 case omp::ClauseCancellationConstructType::Sections:
4020 return llvm::omp::Directive::OMPD_sections;
4021 case omp::ClauseCancellationConstructType::Taskgroup:
4022 return llvm::omp::Directive::OMPD_taskgroup;
4023 }
4024 llvm_unreachable("Unhandled cancellation construct type");
4025}
4026
4027static LogicalResult
4028convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4029 LLVM::ModuleTranslation &moduleTranslation) {
4030 if (failed(checkImplementationStatus(*op.getOperation())))
4031 return failure();
4032
4033 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4034 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4035
4036 llvm::Value *ifCond = nullptr;
4037 if (Value ifVar = op.getIfExpr())
4038 ifCond = moduleTranslation.lookupValue(ifVar);
4039
4040 llvm::omp::Directive cancelledDirective =
4041 convertCancellationConstructType(op.getCancelDirective());
4042
4043 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4044 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4045
4046 if (failed(handleError(afterIP, *op.getOperation())))
4047 return failure();
4048
4049 builder.restoreIP(afterIP.get());
4050
4051 return success();
4052}
4053
4054static LogicalResult
4055convertOmpCancellationPoint(omp::CancellationPointOp op,
4056 llvm::IRBuilderBase &builder,
4057 LLVM::ModuleTranslation &moduleTranslation) {
4058 if (failed(checkImplementationStatus(*op.getOperation())))
4059 return failure();
4060
4061 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4062 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4063
4064 llvm::omp::Directive cancelledDirective =
4065 convertCancellationConstructType(op.getCancelDirective());
4066
4067 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4068 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4069
4070 if (failed(handleError(afterIP, *op.getOperation())))
4071 return failure();
4072
4073 builder.restoreIP(afterIP.get());
4074
4075 return success();
4076}
4077
4078/// Converts an OpenMP Threadprivate operation into LLVM IR using
4079/// OpenMPIRBuilder.
4080static LogicalResult
4081convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4082 LLVM::ModuleTranslation &moduleTranslation) {
4083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4084 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4085 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4086
4087 if (failed(checkImplementationStatus(opInst)))
4088 return failure();
4089
4090 Value symAddr = threadprivateOp.getSymAddr();
4091 auto *symOp = symAddr.getDefiningOp();
4092
4093 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4094 symOp = asCast.getOperand().getDefiningOp();
4095
4096 if (!isa<LLVM::AddressOfOp>(symOp))
4097 return opInst.emitError("Addressing symbol not found");
4098 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4099
4100 LLVM::GlobalOp global =
4101 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4102 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4103
4104 if (!ompBuilder->Config.isTargetDevice()) {
4105 llvm::Type *type = globalValue->getValueType();
4106 llvm::TypeSize typeSize =
4107 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4108 type);
4109 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4110 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4111 ompLoc, globalValue, size, global.getSymName() + ".cache");
4112 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4113 } else {
4114 moduleTranslation.mapValue(opInst.getResult(0), globalValue);
4115 }
4116
4117 return success();
4118}
4119
4120static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4121convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4122 switch (deviceClause) {
4123 case mlir::omp::DeclareTargetDeviceType::host:
4124 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4125 break;
4126 case mlir::omp::DeclareTargetDeviceType::nohost:
4127 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4128 break;
4129 case mlir::omp::DeclareTargetDeviceType::any:
4130 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4131 break;
4132 }
4133 llvm_unreachable("unhandled device clause");
4134}
4135
4136static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4138 mlir::omp::DeclareTargetCaptureClause captureClause) {
4139 switch (captureClause) {
4140 case mlir::omp::DeclareTargetCaptureClause::to:
4141 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4142 case mlir::omp::DeclareTargetCaptureClause::link:
4143 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4144 case mlir::omp::DeclareTargetCaptureClause::enter:
4145 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4146 case mlir::omp::DeclareTargetCaptureClause::none:
4147 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4148 }
4149 llvm_unreachable("unhandled capture clause");
4150}
4151
4153 Operation *op = value.getDefiningOp();
4154 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4155 op = addrCast->getOperand(0).getDefiningOp();
4156 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4157 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4158 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4159 }
4160 return nullptr;
4161}
4162
/// Builds the suffix appended to a declare-target global's name when its
/// reference pointer is materialised ("..._decl_tgt_ref_ptr"). For private
/// globals the suffix additionally embeds the defining file's unique ID so
/// same-named globals from different files do not collide.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): `loc` is used unchecked below; this assumes the global's
    // location always contains a FileLineColLoc — confirm it cannot be null.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    auto vfs = llvm::vfs::getRealFileSystem();
    // Append "_<fileID in hex>" so the suffix is unique per source file.
    os << llvm::format(
        "_%x",
        ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
4184
4185static bool isDeclareTargetLink(Value value) {
4186 if (auto declareTargetGlobal =
4187 dyn_cast_if_present<omp::DeclareTargetInterface>(
4188 getGlobalOpFromValue(value)))
4189 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4190 omp::DeclareTargetCaptureClause::link)
4191 return true;
4192 return false;
4193}
4194
4195static bool isDeclareTargetTo(Value value) {
4196 if (auto declareTargetGlobal =
4197 dyn_cast_if_present<omp::DeclareTargetInterface>(
4198 getGlobalOpFromValue(value)))
4199 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4200 omp::DeclareTargetCaptureClause::to ||
4201 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4202 omp::DeclareTargetCaptureClause::enter)
4203 return true;
4204 return false;
4205}
4206
4207// Returns the reference pointer generated by the lowering of the declare
4208// target operation in cases where the link clause is used or the to clause is
4209// used in USM mode.
4210static llvm::Value *
4212 LLVM::ModuleTranslation &moduleTranslation) {
4213 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4214 if (auto gOp =
4215 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
4216 if (auto declareTargetGlobal =
4217 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4218 // In this case, we must utilise the reference pointer generated by
4219 // the declare target operation, similar to Clang
4220 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4221 omp::DeclareTargetCaptureClause::link) ||
4222 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4223 omp::DeclareTargetCaptureClause::to &&
4224 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4225 llvm::SmallString<64> suffix =
4226 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
4227
4228 if (gOp.getSymName().contains(suffix))
4229 return moduleTranslation.getLLVMModule()->getNamedValue(
4230 gOp.getSymName());
4231
4232 return moduleTranslation.getLLVMModule()->getNamedValue(
4233 (gOp.getSymName().str() + suffix.str()).str());
4234 }
4235 }
4236 }
4237 return nullptr;
4238}
4239
4240namespace {
4241// Append customMappers information to existing MapInfosTy
4242struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4243 SmallVector<Operation *, 4> Mappers;
4244
4245 /// Append arrays in \a CurInfo.
4246 void append(MapInfosTy &curInfo) {
4247 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4248 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4249 }
4250};
4251// A small helper structure to contain data gathered
4252// for map lowering and coalese it into one area and
4253// avoiding extra computations such as searches in the
4254// llvm module for lowered mapped variables or checking
4255// if something is declare target (and retrieving the
4256// value) more than neccessary.
4257struct MapInfoData : MapInfosTy {
4258 llvm::SmallVector<bool, 4> IsDeclareTarget;
4259 llvm::SmallVector<bool, 4> IsAMember;
4260 // Identify if mapping was added by mapClause or use_device clauses.
4261 llvm::SmallVector<bool, 4> IsAMapping;
4262 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4263 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4264 // Stripped off array/pointer to get the underlying
4265 // element type
4266 llvm::SmallVector<llvm::Type *, 4> BaseType;
4267
4268 /// Append arrays in \a CurInfo.
4269 void append(MapInfoData &CurInfo) {
4270 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4271 CurInfo.IsDeclareTarget.end());
4272 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4273 OriginalValue.append(CurInfo.OriginalValue.begin(),
4274 CurInfo.OriginalValue.end());
4275 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4276 MapInfosTy::append(CurInfo);
4277 }
4278};
4279
4280enum class TargetDirectiveEnumTy : uint32_t {
4281 None = 0,
4282 Target = 1,
4283 TargetData = 2,
4284 TargetEnterData = 3,
4285 TargetExitData = 4,
4286 TargetUpdate = 5
4287};
4288
4289static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4290 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4291 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4292 .Case([](omp::TargetEnterDataOp) {
4293 return TargetDirectiveEnumTy::TargetEnterData;
4294 })
4295 .Case([&](omp::TargetExitDataOp) {
4296 return TargetDirectiveEnumTy::TargetExitData;
4297 })
4298 .Case([&](omp::TargetUpdateOp) {
4299 return TargetDirectiveEnumTy::TargetUpdate;
4300 })
4301 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4302 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4303}
4304
4305} // namespace
4306
4307static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4308 DataLayout &dl) {
4309 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4310 arrTy.getElementType()))
4311 return getArrayElementSizeInBits(nestedArrTy, dl);
4312 return dl.getTypeSizeInBits(arrTy.getElementType());
4313}
4314
4315// This function calculates the size to be offloaded for a specified type, given
4316// its associated map clause (which can contain bounds information which affects
4317// the total size), this size is calculated based on the underlying element type
4318// e.g. given a 1-D array of ints, we will calculate the size from the integer
4319// type * number of elements in the array. This size can be used in other
4320// calculations but is ultimately used as an argument to the OpenMP runtimes
4321// kernel argument structure which is generated through the combinedInfo data
4322// structures.
4323// This function is somewhat equivalent to Clang's getExprTypeSize inside of
4324// CGOpenMPRuntime.cpp.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  // NOTE(review): basePointer and baseType are accepted but never used in
  // this body — confirm whether they can be dropped at the call sites.
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      // Running product of the extents of each bounded dimension.
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      // For (possibly nested) arrays the transfer size is based on the
      // innermost element type, not the aggregate type size.
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No usable bounds: fall back to the static byte size of the type itself.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
4374
4375// Convert the MLIR map flag set to the runtime map flag set for embedding
4376// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4377// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4378// Certain flags are discarded here such as RefPtee and co.
4379static llvm::omp::OpenMPOffloadMappingFlags
4380convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4381 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4382 return (mlirFlags & flag) == flag;
4383 };
4384 const bool hasExplicitMap =
4385 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4386 omp::ClauseMapFlags::none;
4387
4388 llvm::omp::OpenMPOffloadMappingFlags mapType =
4389 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4390
4391 if (mapTypeToBool(omp::ClauseMapFlags::to))
4392 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4393
4394 if (mapTypeToBool(omp::ClauseMapFlags::from))
4395 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4396
4397 if (mapTypeToBool(omp::ClauseMapFlags::always))
4398 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4399
4400 if (mapTypeToBool(omp::ClauseMapFlags::del))
4401 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4402
4403 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4404 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4405
4406 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4407 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4408
4409 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4410 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4411
4412 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4413 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4414
4415 if (mapTypeToBool(omp::ClauseMapFlags::close))
4416 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4417
4418 if (mapTypeToBool(omp::ClauseMapFlags::present))
4419 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4420
4421 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4422 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4423
4424 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4425 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4426
4427 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4428 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4429 if (!hasExplicitMap)
4430 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4431 }
4432
4433 return mapType;
4434}
4435
4437 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
4438 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
4439 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
4440 ArrayRef<Value> useDevAddrOperands = {},
4441 ArrayRef<Value> hasDevAddrOperands = {}) {
4442 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4443 // Check if this is a member mapping and correctly assign that it is, if
4444 // it is a member of a larger object.
4445 // TODO: Need better handling of members, and distinguishing of members
4446 // that are implicitly allocated on device vs explicitly passed in as
4447 // arguments.
4448 // TODO: May require some further additions to support nested record
4449 // types, i.e. member maps that can have member maps.
4450 for (Value mapValue : mapVars) {
4451 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4452 for (auto member : map.getMembers())
4453 if (member == mapOp)
4454 return true;
4455 }
4456 return false;
4457 };
4458
4459 // Process MapOperands
4460 for (Value mapValue : mapVars) {
4461 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4462 Value offloadPtr =
4463 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4464 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4465 mapData.Pointers.push_back(mapData.OriginalValue.back());
4466
4467 if (llvm::Value *refPtr =
4468 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
4469 mapData.IsDeclareTarget.push_back(true);
4470 mapData.BasePointers.push_back(refPtr);
4471 } else if (isDeclareTargetTo(offloadPtr)) {
4472 mapData.IsDeclareTarget.push_back(true);
4473 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4474 } else { // regular mapped variable
4475 mapData.IsDeclareTarget.push_back(false);
4476 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4477 }
4478
4479 mapData.BaseType.push_back(
4480 moduleTranslation.convertType(mapOp.getVarType()));
4481 mapData.Sizes.push_back(
4482 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4483 mapData.BaseType.back(), builder, moduleTranslation));
4484 mapData.MapClause.push_back(mapOp.getOperation());
4485 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
4486 mapData.Names.push_back(LLVM::createMappingInformation(
4487 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4488 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4489 if (mapOp.getMapperId())
4490 mapData.Mappers.push_back(
4492 mapOp, mapOp.getMapperIdAttr()));
4493 else
4494 mapData.Mappers.push_back(nullptr);
4495 mapData.IsAMapping.push_back(true);
4496 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4497 }
4498
4499 auto findMapInfo = [&mapData](llvm::Value *val,
4500 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4501 unsigned index = 0;
4502 bool found = false;
4503 for (llvm::Value *basePtr : mapData.OriginalValue) {
4504 if (basePtr == val && mapData.IsAMapping[index]) {
4505 found = true;
4506 mapData.Types[index] |=
4507 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4508 mapData.DevicePointers[index] = devInfoTy;
4509 }
4510 index++;
4511 }
4512 return found;
4513 };
4514
4515 // Process useDevPtr(Addr)Operands
4516 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4517 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4518 for (Value mapValue : useDevOperands) {
4519 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4520 Value offloadPtr =
4521 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4522 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4523
4524 // Check if map info is already present for this entry.
4525 if (!findMapInfo(origValue, devInfoTy)) {
4526 mapData.OriginalValue.push_back(origValue);
4527 mapData.Pointers.push_back(mapData.OriginalValue.back());
4528 mapData.IsDeclareTarget.push_back(false);
4529 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4530 mapData.BaseType.push_back(
4531 moduleTranslation.convertType(mapOp.getVarType()));
4532 mapData.Sizes.push_back(builder.getInt64(0));
4533 mapData.MapClause.push_back(mapOp.getOperation());
4534 mapData.Types.push_back(
4535 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4536 mapData.Names.push_back(LLVM::createMappingInformation(
4537 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4538 mapData.DevicePointers.push_back(devInfoTy);
4539 mapData.Mappers.push_back(nullptr);
4540 mapData.IsAMapping.push_back(false);
4541 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4542 }
4543 }
4544 };
4545
4546 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4547 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4548
4549 for (Value mapValue : hasDevAddrOperands) {
4550 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4551 Value offloadPtr =
4552 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4553 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4554 auto mapType = convertClauseMapFlags(mapOp.getMapType());
4555 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4556 bool isDevicePtr =
4557 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4558 omp::ClauseMapFlags::none;
4559
4560 mapData.OriginalValue.push_back(origValue);
4561 mapData.BasePointers.push_back(origValue);
4562 mapData.Pointers.push_back(origValue);
4563 mapData.IsDeclareTarget.push_back(false);
4564 mapData.BaseType.push_back(
4565 moduleTranslation.convertType(mapOp.getVarType()));
4566 mapData.Sizes.push_back(
4567 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4568 mapData.MapClause.push_back(mapOp.getOperation());
4569 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4570 // Descriptors are mapped with the ALWAYS flag, since they can get
4571 // rematerialized, so the address of the decriptor for a given object
4572 // may change from one place to another.
4573 mapData.Types.push_back(mapType);
4574 // Technically it's possible for a non-descriptor mapping to have
4575 // both has-device-addr and ALWAYS, so lookup the mapper in case it
4576 // exists.
4577 if (mapOp.getMapperId()) {
4578 mapData.Mappers.push_back(
4580 mapOp, mapOp.getMapperIdAttr()));
4581 } else {
4582 mapData.Mappers.push_back(nullptr);
4583 }
4584 } else {
4585 // For is_device_ptr we need the map type to propagate so the runtime
4586 // can materialize the device-side copy of the pointer container.
4587 mapData.Types.push_back(
4588 isDevicePtr ? mapType
4589 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4590 mapData.Mappers.push_back(nullptr);
4591 }
4592 mapData.Names.push_back(LLVM::createMappingInformation(
4593 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4594 mapData.DevicePointers.push_back(
4595 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4596 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4597 mapData.IsAMapping.push_back(false);
4598 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4599 }
4600}
4601
4602static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4603 auto *res = llvm::find(mapData.MapClause, memberOp);
4604 assert(res != mapData.MapClause.end() &&
4605 "MapInfoOp for member not found in MapData, cannot return index");
4606 return std::distance(mapData.MapClause.begin(), res);
4607}
4608
4610 omp::MapInfoOp mapInfo) {
4611 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4612 llvm::SmallVector<size_t> occludedChildren;
4613 llvm::sort(
4614 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
4615 // Bail early if we are asked to look at the same index. If we do not
4616 // bail early, we can end up mistakenly adding indices to
4617 // occludedChildren. This can occur with some types of libc++ hardening.
4618 if (a == b)
4619 return false;
4620
4621 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4622 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4623
4624 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4625 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4626 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4627
4628 if (aIndex == bIndex)
4629 continue;
4630
4631 if (aIndex < bIndex)
4632 return true;
4633
4634 if (aIndex > bIndex)
4635 return false;
4636 }
4637
4638 // Iterated up until the end of the smallest member and
4639 // they were found to be equal up to that point, so select
4640 // the member with the lowest index count, so the "parent"
4641 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4642 if (memberAParent)
4643 occludedChildren.push_back(b);
4644 else
4645 occludedChildren.push_back(a);
4646 return memberAParent;
4647 });
4648
4649 // We remove children from the index list that are overshadowed by
4650 // a parent, this prevents us retrieving these as the first or last
4651 // element when the parent is the correct element in these cases.
4652 for (auto v : occludedChildren)
4653 indices.erase(std::remove(indices.begin(), indices.end(), v),
4654 indices.end());
4655}
4656
4657static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4658 bool first) {
4659 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4660 // Only 1 member has been mapped, we can return it.
4661 if (indexAttr.size() == 1)
4662 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4663 llvm::SmallVector<size_t> indices(indexAttr.size());
4664 std::iota(indices.begin(), indices.end(), 0);
4665 sortMapIndices(indices, mapInfo);
4666 return llvm::cast<omp::MapInfoOp>(
4667 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4668 .getDefiningOp());
4669}
4670
4671/// This function calculates the array/pointer offset for map data provided
4672/// with bounds operations, e.g. when provided something like the following:
4673///
4674/// Fortran
4675/// map(tofrom: array(2:5, 3:2))
4676///
4677/// We must calculate the initial pointer offset to pass across, this function
4678/// performs this using bounds.
4679///
4680/// TODO/WARNING: This only supports Fortran's column major indexing currently
4681/// as is noted in the note below and comments in the function, we must extend
4682/// this function when we add a C++ frontend.
4683/// NOTE: which while specified in row-major order it currently needs to be
4684/// flipped for Fortran's column order array allocation and access (as
4685/// opposed to C++'s row-major, hence the backwards processing where order is
4686/// important). This is likely important to keep in mind for the future when
4687/// we incorporate a C++ frontend, both frontends will need to agree on the
4688/// ordering of generated bounds operations (one may have to flip them) to
4689/// make the below lowering frontend agnostic. The offload size
4690/// calcualtion may also have to be adjusted for C++.
4691static std::vector<llvm::Value *>
4693 llvm::IRBuilderBase &builder, bool isArrayTy,
4694 OperandRange bounds) {
4695 std::vector<llvm::Value *> idx;
4696 // There's no bounds to calculate an offset from, we can safely
4697 // ignore and return no indices.
4698 if (bounds.empty())
4699 return idx;
4700
4701 // If we have an array type, then we have its type so can treat it as a
4702 // normal GEP instruction where the bounds operations are simply indexes
4703 // into the array. We currently do reverse order of the bounds, which
4704 // I believe leans more towards Fortran's column-major in memory.
4705 if (isArrayTy) {
4706 idx.push_back(builder.getInt64(0));
4707 for (int i = bounds.size() - 1; i >= 0; --i) {
4708 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4709 bounds[i].getDefiningOp())) {
4710 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4711 }
4712 }
4713 } else {
4714 // If we do not have an array type, but we have bounds, then we're dealing
4715 // with a pointer that's being treated like an array and we have the
4716 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
4717 // address (pointer pointing to the actual data) so we must caclulate the
4718 // offset using a single index which the following loop attempts to
4719 // compute using the standard column-major algorithm e.g for a 3D array:
4720 //
4721 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
4722 //
4723 // It is of note that it's doing column-major rather than row-major at the
4724 // moment, but having a way for the frontend to indicate which major format
4725 // to use or standardizing/canonicalizing the order of the bounds to compute
4726 // the offset may be useful in the future when there's other frontends with
4727 // different formats.
4728 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4729 for (int i = bounds.size() - 1; i >= 0; --i) {
4730 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4731 bounds[i].getDefiningOp())) {
4732 if (i == ((int)bounds.size() - 1))
4733 idx.emplace_back(
4734 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4735 else
4736 idx.back() = builder.CreateAdd(
4737 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4738 boundOp.getExtent())),
4739 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4740 }
4741 }
4742 }
4743
4744 return idx;
4745}
4746
4748 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4749 return cast<IntegerAttr>(value).getInt();
4750 });
4751}
4752
4753// Gathers members that are overlapping in the parent, excluding members that
4754// themselves overlap, keeping the top-most (closest to parents level) map.
4755static void
// NOTE(review): the line with the function name and first parameter was lost
// in extraction; from the call site (getOverlappedMembers(overlapIdxs,
// parentClause)) it is getOverlappedMembers(llvm::SmallVectorImpl<size_t>
// &overlapMapDataIdxs, ...) — confirm against upstream.
4757 omp::MapInfoOp parentOp) {
4758 // No members mapped, no overlaps.
4759 if (parentOp.getMembers().empty())
4760 return;
4761
4762 // Single member, we can insert and return early.
4763 if (parentOp.getMembers().size() == 1) {
4764 overlapMapDataIdxs.push_back(0);
4765 return;
4766 }
4767
4768 // 1) collect list of top-level overlapping members from MemberOp
// NOTE(review): the declaration of `memberByIndex` — judging by its uses, a
// vector of (member index, member-indices ArrayAttr) pairs — was lost in
// extraction; confirm against upstream.
4770 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
4771 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
4772 memberByIndex.push_back(
4773 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
4774
4775 // Sort the smallest first (higher up the parent -> member chain), so that
4776 // when we remove members, we remove as much as we can in the initial
4777 // iterations, shortening the number of passes required.
4778 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
4779 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
4780
4781 // Remove elements from the vector if there is a parent element that
4782 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
4783 // [0,2].. etc.
// NOTE(review): the declaration of `skipList` was lost in extraction —
// confirm against upstream.
4785 for (auto v : memberByIndex) {
4786 llvm::SmallVector<int64_t> vArr(v.second.size());
// NOTE(review): vArr is size-constructed and getAsIntegers appends via
// back_inserter, so vArr carries a zero prefix before the converted values —
// confirm this is intended for the prefix comparison below.
4787 getAsIntegers(v.second, vArr);
// NOTE(review): two hazards to verify here: (1) if no other member has v's
// indices as a prefix, find_if returns end() and the dereference is UB;
// (2) std::equal reads vArr.size() elements of xArr BEFORE the
// `xArr.size() >= vArr.size()` guard is evaluated, over-reading when xArr is
// shorter. Confirm inputs make both safe, or guard them.
4788 skipList.push_back(
4789 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
4790 if (v == x)
4791 return false;
4792 llvm::SmallVector<int64_t> xArr(x.second.size());
4793 getAsIntegers(x.second, xArr);
4794 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
4795 xArr.size() >= vArr.size();
4796 }));
4797 }
4798
4799 // Collect the indices, as we need the base pointer etc. from the MapData
4800 // structure which is primarily accessible via index at the moment.
4801 for (auto v : memberByIndex)
4802 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
4803 overlapMapDataIdxs.push_back(v.first);
4804}
4805
4806// The intent is to verify if the mapped data being passed is a
4807// pointer -> pointee that requires special handling in certain cases,
4808// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4809//
4810// There may be a better way to verify this, but unfortunately with
4811// opaque pointers we lose the ability to easily check if something is
4812// a pointer whilst maintaining access to the underlying type.
4813static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4814 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4815 if (mapOp.getVarPtrPtr())
4816 return true;
4817
4818 // If the map data is declare target with a link clause, then it's represented
4819 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4820 // no relation to pointers.
4821 if (isDeclareTargetLink(mapOp.getVarPtr()))
4822 return true;
4823
4824 return false;
4825}
4826
4827// This creates two insertions into the MapInfosTy data structure for the
4828// "parent" of a set of members, (usually a container e.g.
4829// class/structure/derived type) when subsequent members have also been
4830// explicitly mapped on the same map clause. Certain types, such as Fortran
4831// descriptors are mapped like this as well, however, the members are
4832// implicit as far as a user is concerned, but we must explicitly map them
4833// internally.
4834//
4835// This function also returns the memberOfFlag for this particular parent,
4836// which is utilised in subsequent member mappings (by modifying there map type
4837// with it) to indicate that a member is part of this parent and should be
4838// treated by the runtime as such. Important to achieve the correct mapping.
4839//
4840// This function borrows a lot from Clang's emitCombinedEntry function
4841// inside of CGOpenMPRuntime.cpp
4842static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
4843 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4844 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4845 MapInfoData &mapData, uint64_t mapDataIndex,
4846 TargetDirectiveEnumTy targetDirective) {
4847 assert(!ompBuilder.Config.isTargetDevice() &&
4848 "function only supported for host device codegen");
4849
4850 // Map the first segment of the parent. If a user-defined mapper is attached,
4851 // include the parent's to/from-style bits (and common modifiers) in this
4852 // base entry so the mapper receives correct copy semantics via its 'type'
4853 // parameter. Also keep TARGET_PARAM when required for kernel arguments.
4854 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
4855 (targetDirective == TargetDirectiveEnumTy::Target &&
4856 !mapData.IsDeclareTarget[mapDataIndex])
4857 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
4858 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4859
4860 // Detect if this mapping uses a user-defined mapper.
4861 bool hasUserMapper = mapData.Mappers[mapDataIndex] != nullptr;
4862 if (hasUserMapper) {
4863 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
4864 // Preserve relevant map-type bits from the parent clause. These include
4865 // the copy direction (TO/FROM), as well as commonly used modifiers that
4866 // should be visible to the mapper for correct behaviour.
4867 mapFlags parentFlags = mapData.Types[mapDataIndex];
4868 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
4869 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
4870 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
4871 baseFlag |= (parentFlags & preserve);
4872 }
4873
// Emit the parent's base entry: flags, device pointer, mapper, name and base
// pointer now; the runtime pointer and size are appended below once computed.
4874 combinedInfo.Types.emplace_back(baseFlag);
4875 combinedInfo.DevicePointers.emplace_back(
4876 mapData.DevicePointers[mapDataIndex]);
4877 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4878 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
4879 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4880 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
4881
4882 // Calculate size of the parent object being mapped based on the
4883 // addresses at runtime, highAddr - lowAddr = size. This of course
4884 // doesn't factor in allocated data like pointers, hence the further
4885 // processing of members specified by users, or in the case of
4886 // Fortran pointers and allocatables, the mapping of the pointed to
4887 // data by the descriptor (which itself, is a structure containing
4888 // runtime information on the dynamically allocated data).
4889 auto parentClause =
4890 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4891 llvm::Value *lowAddr, *highAddr;
4892 if (!parentClause.getPartialMap()) {
4893 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4894 builder.getPtrTy());
4895 highAddr = builder.CreatePointerCast(
4896 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4897 mapData.Pointers[mapDataIndex], 1),
4898 builder.getPtrTy());
4899 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4900 } else {
// For a partial map only the mapped members are spanned: the region runs from
// the first mapped member's pointer to one element past the last one.
4901 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4902 int firstMemberIdx = getMapDataMemberIdx(
4903 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
4904 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
4905 builder.getPtrTy());
4906 int lastMemberIdx = getMapDataMemberIdx(
4907 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
4908 highAddr = builder.CreatePointerCast(
4909 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
4910 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
4911 builder.getPtrTy());
4912 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
4913 }
4914
// NOTE(review): this segment's size uses /*isSigned=*/false while the overlap
// and trailing segment sizes below are cast with isSigned=true — confirm the
// asymmetry is intentional.
4915 llvm::Value *size = builder.CreateIntCast(
4916 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
4917 builder.getInt64Ty(),
4918 /*isSigned=*/false);
4919 combinedInfo.Sizes.push_back(size);
4920
4921 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
4922 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
4923
4924 // This creates the initial MEMBER_OF mapping that consists of
4925 // the parent/top level container (same as above effectively, except
4926 // with a fixed initial compile time size and separate maptype which
4927 // indicates the true mape type (tofrom etc.). This parent mapping is
4928 // only relevant if the structure in its totality is being mapped,
4929 // otherwise the above suffices.
4930 if (!parentClause.getPartialMap()) {
4931 // TODO: This will need to be expanded to include the whole host of logic
4932 // for the map flags that Clang currently supports (e.g. it should do some
4933 // further case specific flag modifications). For the moment, it handles
4934 // what we support as expected.
4935 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
4936 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
4937 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
4938 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4939 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4940
// For target update (or when CLOSE is requested) emit the whole parent as a
// single MEMBER_OF entry; otherwise split the parent around its overlapped
// members below.
4941 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
4942 combinedInfo.Types.emplace_back(mapFlag);
4943 combinedInfo.DevicePointers.emplace_back(
4944 mapData.DevicePointers[mapDataIndex]);
4945 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
4946 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4947 combinedInfo.BasePointers.emplace_back(
4948 mapData.BasePointers[mapDataIndex]);
4949 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
4950 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
4951 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4952 } else {
4953 llvm::SmallVector<size_t> overlapIdxs;
4954 // Find all of the members that "overlap", i.e. occlude other members that
4955 // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
4956 // [0,2], but not [1,0].
4957 getOverlappedMembers(overlapIdxs, parentClause);
4958 // We need to make sure the overlapped members are sorted in order of
4959 // lowest address to highest address.
4960 sortMapIndices(overlapIdxs, parentClause);
4961
4962 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
4963 builder.getPtrTy());
4964 highAddr = builder.CreatePointerCast(
4965 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
4966 mapData.Pointers[mapDataIndex], 1),
4967 builder.getPtrTy());
4968
4969 // TODO: We may want to skip arrays/array sections in this as Clang does.
4970 // It appears to be an optimisation rather than a necessity though,
4971 // but this requires further investigation. However, we would have to make
4972 // sure to not exclude maps with bounds that ARE pointers, as these are
4973 // processed as separate components, i.e. pointer + data.
// Each iteration emits the gap [lowAddr, overlapped member) and then moves
// lowAddr one element past the overlapped member for the next segment.
4974 for (auto v : overlapIdxs) {
4975 auto mapDataOverlapIdx = getMapDataMemberIdx(
4976 mapData,
4977 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
4978 combinedInfo.Types.emplace_back(mapFlag);
4979 combinedInfo.DevicePointers.emplace_back(
4980 mapData.DevicePointers[mapDataOverlapIdx]);
4981 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
4982 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
4983 combinedInfo.BasePointers.emplace_back(
4984 mapData.BasePointers[mapDataIndex]);
4985 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
4986 combinedInfo.Pointers.emplace_back(lowAddr);
4987 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
4988 builder.CreatePtrDiff(builder.getInt8Ty(),
4989 mapData.OriginalValue[mapDataOverlapIdx],
4990 lowAddr),
4991 builder.getInt64Ty(), /*isSigned=*/true));
// Overlapped pointer members occupy a pointer-sized slot in the parent, so
// step by ptr type; otherwise step by the member's own type.
4992 lowAddr = builder.CreateConstGEP1_32(
4993 checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
4994 mapData.MapClause[mapDataOverlapIdx]))
4995 ? builder.getPtrTy()
4996 : mapData.BaseType[mapDataOverlapIdx],
4997 mapData.BasePointers[mapDataOverlapIdx], 1);
4998 }
4999
// Final segment: from just past the last overlapped member up to the end of
// the parent object.
5000 combinedInfo.Types.emplace_back(mapFlag);
5001 combinedInfo.DevicePointers.emplace_back(
5002 mapData.DevicePointers[mapDataIndex]);
5003 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5004 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5005 combinedInfo.BasePointers.emplace_back(
5006 mapData.BasePointers[mapDataIndex]);
5007 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
5008 combinedInfo.Pointers.emplace_back(lowAddr);
5009 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5010 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5011 builder.getInt64Ty(), true));
5012 }
5013 }
5014 return memberOfFlag;
5015}
5016
5017// This function is intended to add explicit mappings of members
// NOTE(review): the signature line was lost in extraction; from the caller in
// processMapWithMembersOf it is processMapMembersWithParent( — confirm
// against upstream.
5019 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5020 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5021 MapInfoData &mapData, uint64_t mapDataIndex,
5022 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5023 TargetDirectiveEnumTy targetDirective) {
5024 assert(!ompBuilder.Config.isTargetDevice() &&
5025 "function only supported for host device codegen");
5026
5027 auto parentClause =
5028 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5029
// Emit one (pointer maps: two) combinedInfo entries per explicitly mapped
// member of the parent clause.
5030 for (auto mappedMembers : parentClause.getMembers()) {
5031 auto memberClause =
5032 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5033 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5034
5035 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5036
5037 // If we're currently mapping a pointer to a block of data, we must
5038 // initially map the pointer, and then attatch/bind the data with a
5039 // subsequent map to the pointer. This segment of code generates the
5040 // pointer mapping, which can in certain cases be optimised out as Clang
5041 // currently does in its lowering. However, for the moment we do not do so,
5042 // in part as we currently have substantially less information on the data
5043 // being mapped at this stage.
5044 if (checkIfPointerMap(memberClause)) {
5045 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5046 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5047 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5048 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5049 combinedInfo.Types.emplace_back(mapFlag);
5050 combinedInfo.DevicePointers.emplace_back(
5051 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5052 combinedInfo.Mappers.emplace_back(nullptr);
5053 combinedInfo.Names.emplace_back(
5054 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5055 combinedInfo.BasePointers.emplace_back(
5056 mapData.BasePointers[mapDataIndex]);
5057 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
// The pointer slot itself is mapped with the target's pointer width as size.
5058 combinedInfo.Sizes.emplace_back(builder.getInt64(
5059 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5060 }
5061
5062 // Same MemberOfFlag to indicate its link with parent and other members
5063 // of.
5064 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5065 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5066 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5067 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
// A partial-map parent may carry its variable in varPtrPtr rather than
// varPtr; fall back accordingly when querying declare-target-to.
5068 bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5069 ? parentClause.getVarPtr()
5070 : parentClause.getVarPtrPtr());
5071 if (checkIfPointerMap(memberClause) &&
5072 (!isDeclTargetTo ||
5073 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5074 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5075 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5076 }
5077
5078 combinedInfo.Types.emplace_back(mapFlag);
5079 combinedInfo.DevicePointers.emplace_back(
5080 mapData.DevicePointers[memberDataIdx]);
5081 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5082 combinedInfo.Names.emplace_back(
5083 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
// Pointer members use their own entry as base; plain members use the parent.
5084 uint64_t basePointerIndex =
5085 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5086 combinedInfo.BasePointers.emplace_back(
5087 mapData.BasePointers[basePointerIndex]);
5088 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5089
// Substitute size 0 at runtime when the member pointer is null, so no data
// transfer is attempted through a null pointee.
5090 llvm::Value *size = mapData.Sizes[memberDataIdx];
5091 if (checkIfPointerMap(memberClause)) {
5092 size = builder.CreateSelect(
5093 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5094 builder.getInt64(0), size);
5095 }
5096
5097 combinedInfo.Sizes.emplace_back(size);
5098 }
5099}
5100
5101static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5102 MapInfosTy &combinedInfo,
5103 TargetDirectiveEnumTy targetDirective,
5104 int mapDataParentIdx = -1) {
5105 // Declare Target Mappings are excluded from being marked as
5106 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5107 // marked with OMP_MAP_PTR_AND_OBJ instead.
5108 auto mapFlag = mapData.Types[mapDataIdx];
5109 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5110
5111 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5112 if (isPtrTy)
5113 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5114
5115 if (targetDirective == TargetDirectiveEnumTy::Target &&
5116 !mapData.IsDeclareTarget[mapDataIdx])
5117 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5118
5119 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5120 !isPtrTy)
5121 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5122
5123 // if we're provided a mapDataParentIdx, then the data being mapped is
5124 // part of a larger object (in a parent <-> member mapping) and in this
5125 // case our BasePointer should be the parent.
5126 if (mapDataParentIdx >= 0)
5127 combinedInfo.BasePointers.emplace_back(
5128 mapData.BasePointers[mapDataParentIdx]);
5129 else
5130 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5131
5132 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5133 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5134 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5135 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5136 combinedInfo.Types.emplace_back(mapFlag);
5137 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5138}
5139
// NOTE(review): the comment/signature lines were lost in extraction; from the
// caller in genMapInfos this is
// processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, —
// confirm against upstream. It lowers a map clause that has explicitly mapped
// members, either as one individual map (single-member partial map) or as a
// parent entry plus member entries.
5141 llvm::IRBuilderBase &builder,
5142 llvm::OpenMPIRBuilder &ompBuilder,
5143 DataLayout &dl, MapInfosTy &combinedInfo,
5144 MapInfoData &mapData, uint64_t mapDataIndex,
5145 TargetDirectiveEnumTy targetDirective) {
5146 assert(!ompBuilder.Config.isTargetDevice() &&
5147 "function only supported for host device codegen");
5148
5149 auto parentClause =
5150 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5151
5152 // If we have a partial map (no parent referenced in the map clauses of the
5153 // directive, only members) and only a single member, we do not need to bind
5154 // the map of the member to the parent, we can pass the member separately.
5155 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5156 auto memberClause = llvm::cast<omp::MapInfoOp>(
5157 parentClause.getMembers()[0].getDefiningOp());
5158 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5159 // Note: Clang treats arrays with explicit bounds that fall into this
5160 // category as a parent with map case, however, it seems this isn't a
5161 // requirement, and processing them as an individual map is fine. So,
5162 // we will handle them as individual maps for the moment, as it's
5163 // difficult for us to check this as we always require bounds to be
5164 // specified currently and it's also marginally more optimal (single
5165 // map rather than two). The difference may come from the fact that
5166 // Clang maps array without bounds as pointers (which we do not
5167 // currently do), whereas we treat them as arrays in all cases
5168 // currently.
5169 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
5170 mapDataIndex);
5171 return;
5172 }
5173
// Otherwise map the parent first (obtaining its MEMBER_OF flag), then each
// explicitly mapped member tagged with that flag.
5174 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5175 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
5176 combinedInfo, mapData, mapDataIndex,
5177 targetDirective);
5178 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
5179 combinedInfo, mapData, mapDataIndex,
5180 memberOfParentFlag, targetDirective);
5181}
5182
5183// This is a variation on Clang's GenerateOpenMPCapturedVars, which
5184// generates different operation (e.g. load/store) combinations for
5185// arguments to the kernel, based on map capture kinds which are then
5186// utilised in the combinedInfo in place of the original Map value.
5187static void
5188createAlteredByCaptureMap(MapInfoData &mapData,
5189 LLVM::ModuleTranslation &moduleTranslation,
5190 llvm::IRBuilderBase &builder) {
5191 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5192 "function only supported for host device codegen");
5193 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5194 // if it's declare target, skip it, it's handled separately.
5195 if (!mapData.IsDeclareTarget[i]) {
5196 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5197 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5198 bool isPtrTy = checkIfPointerMap(mapOp);
5199
5200 // Currently handles array sectioning lowerbound case, but more
5201 // logic may be required in the future. Clang invokes EmitLValue,
5202 // which has specialised logic for special Clang types such as user
5203 // defines, so it is possible we will have to extend this for
5204 // structures or other complex types. As the general idea is that this
5205 // function mimics some of the logic from Clang that we require for
5206 // kernel argument passing from host -> device.
5207 switch (captureKind) {
5208 case omp::VariableCaptureKind::ByRef: {
5209 llvm::Value *newV = mapData.Pointers[i];
5210 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
5211 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5212 mapOp.getBounds());
5213 if (isPtrTy)
5214 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5215
5216 if (!offsetIdx.empty())
5217 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5218 "array_offset");
5219 mapData.Pointers[i] = newV;
5220 } break;
5221 case omp::VariableCaptureKind::ByCopy: {
5222 llvm::Type *type = mapData.BaseType[i];
5223 llvm::Value *newV;
5224 if (mapData.Pointers[i]->getType()->isPointerTy())
5225 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5226 else
5227 newV = mapData.Pointers[i];
5228
5229 if (!isPtrTy) {
5230 auto curInsert = builder.saveIP();
5231 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5232 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
5233 auto *memTempAlloc =
5234 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
5235 builder.SetCurrentDebugLocation(DbgLoc);
5236 builder.restoreIP(curInsert);
5237
5238 builder.CreateStore(newV, memTempAlloc);
5239 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5240 }
5241
5242 mapData.Pointers[i] = newV;
5243 mapData.BasePointers[i] = newV;
5244 } break;
5245 case omp::VariableCaptureKind::This:
5246 case omp::VariableCaptureKind::VLAType:
5247 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
5248 break;
5249 }
5250 }
5251 }
5252}
5253
5254// Generate all map related information and fill the combinedInfo.
5255static void genMapInfos(llvm::IRBuilderBase &builder,
5256 LLVM::ModuleTranslation &moduleTranslation,
5257 DataLayout &dl, MapInfosTy &combinedInfo,
5258 MapInfoData &mapData,
5259 TargetDirectiveEnumTy targetDirective) {
5260 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5261 "function only supported for host device codegen");
5262 // We wish to modify some of the methods in which arguments are
5263 // passed based on their capture type by the target region, this can
5264 // involve generating new loads and stores, which changes the
5265 // MLIR value to LLVM value mapping, however, we only wish to do this
5266 // locally for the current function/target and also avoid altering
5267 // ModuleTranslation, so we remap the base pointer or pointer stored
5268 // in the map infos corresponding MapInfoData, which is later accessed
5269 // by genMapInfos and createTarget to help generate the kernel and
5270 // kernel arg structure. It primarily becomes relevant in cases like
5271 // bycopy, or byref range'd arrays. In the default case, we simply
5272 // pass thee pointer byref as both basePointer and pointer.
5273 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5274
5275 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5276
5277 // We operate under the assumption that all vectors that are
5278 // required in MapInfoData are of equal lengths (either filled with
5279 // default constructed data or appropiate information) so we can
5280 // utilise the size from any component of MapInfoData, if we can't
5281 // something is missing from the initial MapInfoData construction.
5282 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5283 // NOTE/TODO: We currently do not support arbitrary depth record
5284 // type mapping.
5285 if (mapData.IsAMember[i])
5286 continue;
5287
5288 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5289 if (!mapInfoOp.getMembers().empty()) {
5290 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5291 combinedInfo, mapData, i, targetDirective);
5292 continue;
5293 }
5294
5295 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5296 }
5297}
5298
// Forward declaration: emitUserDefinedMapper and
// getOrCreateUserDefinedMapperFunc are mutually recursive — nested mappers
// reach emitUserDefinedMapper through getOrCreateUserDefinedMapperFunc's
// cache lookup.
5299static llvm::Expected<llvm::Function *>
5300emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
5301 LLVM::ModuleTranslation &moduleTranslation,
5302 llvm::StringRef mapperFuncName,
5303 TargetDirectiveEnumTy targetDirective);
5304
5305static llvm::Expected<llvm::Function *>
5306getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5307 LLVM::ModuleTranslation &moduleTranslation,
5308 TargetDirectiveEnumTy targetDirective) {
5309 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5310 "function only supported for host device codegen");
5311 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5312 std::string mapperFuncName =
5313 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5314 {"omp_mapper", declMapperOp.getSymName()});
5315
5316 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5317 return lookupFunc;
5318
5319 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5320 mapperFuncName, targetDirective);
5321}
5322
5323static llvm::Expected<llvm::Function *>
5324emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
5325 LLVM::ModuleTranslation &moduleTranslation,
5326 llvm::StringRef mapperFuncName,
5327 TargetDirectiveEnumTy targetDirective) {
5328 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5329 "function only supported for host device codegen");
5330 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5331 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
5332 DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
5333 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5334 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
5335 SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
5336
5337 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5338
5339 // Fill up the arrays with all the mapped variables.
5340 MapInfosTy combinedInfo;
5341 auto genMapInfoCB =
5342 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
5343 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
5344 builder.restoreIP(codeGenIP);
5345 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
5346 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
5347 builder.GetInsertBlock());
5348 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
5349 /*ignoreArguments=*/true,
5350 builder)))
5351 return llvm::make_error<PreviouslyReportedError>();
5352 MapInfoData mapData;
5353 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
5354 builder);
5355 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
5356 targetDirective);
5357
5358 // Drop the mapping that is no longer necessary so that the same region
5359 // can be processed multiple times.
5360 moduleTranslation.forgetMapping(declMapperOp.getRegion());
5361 return combinedInfo;
5362 };
5363
5364 auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
5365 if (!combinedInfo.Mappers[i])
5366 return nullptr;
5367 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5368 moduleTranslation, targetDirective);
5369 };
5370
5371 llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
5372 genMapInfoCB, varType, mapperFuncName, customMapperCB);
5373 if (!newFn)
5374 return newFn.takeError();
5375 moduleTranslation.mapFunction(mapperFuncName, *newFn);
5376 return *newFn;
5377}
5378
5379static LogicalResult
5380convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
5381 LLVM::ModuleTranslation &moduleTranslation) {
5382 llvm::Value *ifCond = nullptr;
5383 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5384 SmallVector<Value> mapVars;
5385 SmallVector<Value> useDevicePtrVars;
5386 SmallVector<Value> useDeviceAddrVars;
5387 llvm::omp::RuntimeFunction RTLFn;
5388 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
5389 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5390
5391 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5392 llvm::OpenMPIRBuilder::TargetDataInfo info(
5393 /*RequiresDevicePointerInfo=*/true,
5394 /*SeparateBeginEndCalls=*/true);
5395 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5396 bool isOffloadEntry =
5397 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5398
5399 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
5400 llvm::Value *v = moduleTranslation.lookupValue(dev);
5401 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
5402 };
5403
5404 LogicalResult result =
5406 .Case([&](omp::TargetDataOp dataOp) {
5407 if (failed(checkImplementationStatus(*dataOp)))
5408 return failure();
5409
5410 if (auto ifVar = dataOp.getIfExpr())
5411 ifCond = moduleTranslation.lookupValue(ifVar);
5412
5413 if (mlir::Value devId = dataOp.getDevice())
5414 deviceID = getDeviceID(devId);
5415
5416 mapVars = dataOp.getMapVars();
5417 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5418 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5419 return success();
5420 })
5421 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5422 if (failed(checkImplementationStatus(*enterDataOp)))
5423 return failure();
5424
5425 if (auto ifVar = enterDataOp.getIfExpr())
5426 ifCond = moduleTranslation.lookupValue(ifVar);
5427
5428 if (mlir::Value devId = enterDataOp.getDevice())
5429 deviceID = getDeviceID(devId);
5430
5431 RTLFn =
5432 enterDataOp.getNowait()
5433 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5434 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5435 mapVars = enterDataOp.getMapVars();
5436 info.HasNoWait = enterDataOp.getNowait();
5437 return success();
5438 })
5439 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5440 if (failed(checkImplementationStatus(*exitDataOp)))
5441 return failure();
5442
5443 if (auto ifVar = exitDataOp.getIfExpr())
5444 ifCond = moduleTranslation.lookupValue(ifVar);
5445
5446 if (mlir::Value devId = exitDataOp.getDevice())
5447 deviceID = getDeviceID(devId);
5448
5449 RTLFn = exitDataOp.getNowait()
5450 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5451 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5452 mapVars = exitDataOp.getMapVars();
5453 info.HasNoWait = exitDataOp.getNowait();
5454 return success();
5455 })
5456 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5457 if (failed(checkImplementationStatus(*updateDataOp)))
5458 return failure();
5459
5460 if (auto ifVar = updateDataOp.getIfExpr())
5461 ifCond = moduleTranslation.lookupValue(ifVar);
5462
5463 if (mlir::Value devId = updateDataOp.getDevice())
5464 deviceID = getDeviceID(devId);
5465
5466 RTLFn =
5467 updateDataOp.getNowait()
5468 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5469 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5470 mapVars = updateDataOp.getMapVars();
5471 info.HasNoWait = updateDataOp.getNowait();
5472 return success();
5473 })
5474 .DefaultUnreachable("unexpected operation");
5475
5476 if (failed(result))
5477 return failure();
5478 // Pretend we have IF(false) if we're not doing offload.
5479 if (!isOffloadEntry)
5480 ifCond = builder.getFalse();
5481
5482 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5483 MapInfoData mapData;
5484 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
5485 builder, useDevicePtrVars, useDeviceAddrVars);
5486
5487 // Fill up the arrays with all the mapped variables.
5488 MapInfosTy combinedInfo;
5489 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5490 builder.restoreIP(codeGenIP);
5491 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5492 targetDirective);
5493 return combinedInfo;
5494 };
5495
5496 // Define a lambda to apply mappings between use_device_addr and
5497 // use_device_ptr base pointers, and their associated block arguments.
5498 auto mapUseDevice =
5499 [&moduleTranslation](
5500 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5502 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
5503 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
5504 for (auto [arg, useDevVar] :
5505 llvm::zip_equal(blockArgs, useDeviceVars)) {
5506
5507 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5508 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5509 : mapInfoOp.getVarPtr();
5510 };
5511
5512 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5513 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5514 mapInfoData.MapClause, mapInfoData.DevicePointers,
5515 mapInfoData.BasePointers)) {
5516 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5517 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5518 devicePointer != type)
5519 continue;
5520
5521 if (llvm::Value *devPtrInfoMap =
5522 mapper ? mapper(basePointer) : basePointer) {
5523 moduleTranslation.mapValue(arg, devPtrInfoMap);
5524 break;
5525 }
5526 }
5527 }
5528 };
5529
5530 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5531 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5532 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5533 // We must always restoreIP regardless of doing anything the caller
5534 // does not restore it, leading to incorrect (no) branch generation.
5535 builder.restoreIP(codeGenIP);
5536 assert(isa<omp::TargetDataOp>(op) &&
5537 "BodyGen requested for non TargetDataOp");
5538 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5539 Region &region = cast<omp::TargetDataOp>(op).getRegion();
5540 switch (bodyGenType) {
5541 case BodyGenTy::Priv:
5542 // Check if any device ptr/addr info is available
5543 if (!info.DevicePtrInfoMap.empty()) {
5544 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5545 blockArgIface.getUseDeviceAddrBlockArgs(),
5546 useDeviceAddrVars, mapData,
5547 [&](llvm::Value *basePointer) -> llvm::Value * {
5548 if (!info.DevicePtrInfoMap[basePointer].second)
5549 return nullptr;
5550 return builder.CreateLoad(
5551 builder.getPtrTy(),
5552 info.DevicePtrInfoMap[basePointer].second);
5553 });
5554 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5555 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5556 mapData, [&](llvm::Value *basePointer) {
5557 return info.DevicePtrInfoMap[basePointer].second;
5558 });
5559
5560 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5561 moduleTranslation)))
5562 return llvm::make_error<PreviouslyReportedError>();
5563 }
5564 break;
5565 case BodyGenTy::DupNoPriv:
5566 if (info.DevicePtrInfoMap.empty()) {
5567 // For host device we still need to do the mapping for codegen,
5568 // otherwise it may try to lookup a missing value.
5569 if (!ompBuilder->Config.IsTargetDevice.value_or(false)) {
5570 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5571 blockArgIface.getUseDeviceAddrBlockArgs(),
5572 useDeviceAddrVars, mapData);
5573 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5574 blockArgIface.getUseDevicePtrBlockArgs(),
5575 useDevicePtrVars, mapData);
5576 }
5577 }
5578 break;
5579 case BodyGenTy::NoPriv:
5580 // If device info is available then region has already been generated
5581 if (info.DevicePtrInfoMap.empty()) {
5582 // For device pass, if use_device_ptr(addr) mappings were present,
5583 // we need to link them here before codegen.
5584 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
5585 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5586 blockArgIface.getUseDeviceAddrBlockArgs(),
5587 useDeviceAddrVars, mapData);
5588 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5589 blockArgIface.getUseDevicePtrBlockArgs(),
5590 useDevicePtrVars, mapData);
5591 }
5592
5593 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5594 moduleTranslation)))
5595 return llvm::make_error<PreviouslyReportedError>();
5596 }
5597 break;
5598 }
5599 return builder.saveIP();
5600 };
5601
5602 auto customMapperCB =
5603 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5604 if (!combinedInfo.Mappers[i])
5605 return nullptr;
5606 info.HasMapper = true;
5607 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5608 moduleTranslation, targetDirective);
5609 };
5610
5611 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5612 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5613 findAllocaInsertPoint(builder, moduleTranslation);
5614 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5615 if (isa<omp::TargetDataOp>(op))
5616 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5617 deviceID, ifCond, info, genMapInfoCB,
5618 customMapperCB,
5619 /*MapperFunc=*/nullptr, bodyGenCB,
5620 /*DeviceAddrCB=*/nullptr);
5621 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5622 deviceID, ifCond, info, genMapInfoCB,
5623 customMapperCB, &RTLFn);
5624 }();
5625
5626 if (failed(handleError(afterIP, *op)))
5627 return failure();
5628
5629 builder.restoreIP(*afterIP);
5630 return success();
5631}
5632
/// Translate an `omp.distribute` operation into LLVM IR via the
/// OpenMPIRBuilder's createDistribute entry point.
///
/// If the distribute is immediately nested in an `omp.teams` whose reduction
/// is contained inside the distribute (per
/// teamsReductionContainedInDistribute), the teams reduction variables are
/// allocated/initialized before the region and the reduction epilogue is
/// emitted afterwards from here.
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(opInst);
  // Bail out early on clauses this translation does not yet support.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  /// Process teams op reduction in distribute if the reduction is contained in
  /// the distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  if (doDistributeReduction) {
    // Materialize the teams reduction variables up front so both the region
    // body and the post-region reduction epilogue can refer to them.
    isByRef = getIsByRef(teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(teamsOp, reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
            .getReductionBlockArgs();

        teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
        reductionDecls, privateReductionVariables, reductionVariableMap,
        isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(codeGenIP);
    // Gather the op's private clause variables and their privatizers.
    PrivateVarsInfo privVarsInfo(distributeOp);

        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
    if (handleError(afterAllocas, opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Run the privatizers' init regions on the freshly allocated copies.
    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                    opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Copy firstprivate values into the private copies, optionally followed
    // by a barrier if the op requests one.
    if (failed(copyFirstPrivateVars(
            distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
            privVarsInfo.llvmVars, privVarsInfo.privatizers,
            distributeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
        convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

    // Skip applying a workshare loop below when translating 'distribute
    // parallel do' (it's been already handled by this point while translating
    // the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
                                      : omp::ClauseScheduleKind::Static;
      // dist_schedule clauses are ordered - otherwise this should be false
      bool isOrdered = hasDistSchedule;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      llvm::Value *chunk = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      // The canonical loop was built while translating the nested loop nest;
      // retrieve it so the distribute schedule can be applied to it.
      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
              convertToScheduleKind(schedule), chunk, isSimd,
              scheduleMod == omp::ScheduleModifier::monotonic,
              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
              workshareLoopType, false, hasDistSchedule, chunk);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }
    // Run the privatizers' dealloc regions for the private copies.
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
                                  privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
        teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
5768
5769/// Lowers the FlagsAttr which is applied to the module on the device
5770/// pass when offloading, this attribute contains OpenMP RTL globals that can
5771/// be passed as flags to the frontend, otherwise they are set to default
5772static LogicalResult
5773convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5774 LLVM::ModuleTranslation &moduleTranslation) {
5775 if (!cast<mlir::ModuleOp>(op))
5776 return failure();
5777
5778 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5779
5780 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5781 attribute.getOpenmpDeviceVersion());
5782
5783 if (attribute.getNoGpuLib())
5784 return success();
5785
5786 ompBuilder->createGlobalFlag(
5787 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5788 "__omp_rtl_debug_kind");
5789 ompBuilder->createGlobalFlag(
5790 attribute
5791 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5792 ,
5793 "__omp_rtl_assume_teams_oversubscription");
5794 ompBuilder->createGlobalFlag(
5795 attribute
5796 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5797 ,
5798 "__omp_rtl_assume_threads_oversubscription");
5799 ompBuilder->createGlobalFlag(
5800 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5801 "__omp_rtl_assume_no_thread_state");
5802 ompBuilder->createGlobalFlag(
5803 attribute
5804 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5805 ,
5806 "__omp_rtl_assume_no_nested_parallelism");
5807 return success();
5808}
5809
5810static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5811 omp::TargetOp targetOp,
5812 llvm::StringRef parentName = "") {
5813 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5814
5815 assert(fileLoc && "No file found from location");
5816 StringRef fileName = fileLoc.getFilename().getValue();
5817
5818 llvm::sys::fs::UniqueID id;
5819 uint64_t line = fileLoc.getLine();
5820 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5821 size_t fileHash = llvm::hash_value(fileName.str());
5822 size_t deviceId = 0xdeadf17e;
5823 targetInfo =
5824 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5825 } else {
5826 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5827 id.getFile(), line);
5828 }
5829}
5830
/// On the device pass, rewrite all uses within \p func of `declare target`
/// mapped globals so that they go through the reference pointer generated for
/// them (mapData.BasePointers) rather than the original global variable.
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // Restore the builder's insertion point automatically when we return.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // In the case of declare target mapped variables, the basePointer is
    // the reference pointer generated by the convertDeclareTargetAttr
    // method. Whereas the kernelValue is the original variable, so for
    // the device we must replace all uses of this original global variable
    // (stored in kernelValue) with the reference pointer (stored in
    // basePointer for declare target mapped variables), as for device the
    // data is mapped into this reference pointer and should be loaded
    // from it, the original variable is discarded. On host both exist and
    // metadata is generated (elsewhere in the convertDeclareTargetAttr)
    // function to link the two variables in the runtime and then both the
    // reference pointer and the pointer are assigned in the kernel argument
    // structure for the host.
    if (mapData.IsDeclareTarget[i]) {
      // If the original map value is a constant, then we have to make sure all
      // of it's uses within the current kernel/function that we are going to
      // rewrite are converted to instructions, as we will be altering the old
      // use (OriginalValue) from a constant to an instruction, which will be
      // illegal and ICE the compiler if the user is a constant expression of
      // some kind e.g. a constant GEP.
      if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(constant, func, false);

      // The users iterator will get invalidated if we modify an element,
      // so we populate this vector of uses to alter each user on an
      // individual basis to emit its own load (rather than one load for
      // all).
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(user);

      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
          // Only rewrite uses inside the function currently being generated.
          if (insn->getFunction() == func) {
            auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
            llvm::Value *substitute = mapData.BasePointers[i];
            // `declare target link` variables carry an extra level of
            // indirection: load the underlying pointer right before the use.
            if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
                                                         : mapOp.getVarPtr())) {
              builder.SetCurrentDebugLocation(insn->getDebugLoc());
              substitute = builder.CreateLoad(
                  mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
              cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
            }
            user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
          }
        }
      }
    }
  }
}
5888
5889// The createDeviceArgumentAccessor function generates
5890// instructions for retrieving (acessing) kernel
5891// arguments inside of the device kernel for use by
5892// the kernel. This enables different semantics such as
5893// the creation of temporary copies of data allowing
5894// semantics like read-only/no host write back kernel
5895// arguments.
5896//
5897// This currently implements a very light version of Clang's
5898// EmitParmDecl's handling of direct argument handling as well
5899// as a portion of the argument access generation based on
5900// capture types found at the end of emitOutlinedFunctionPrologue
5901// in Clang. The indirect path handling of EmitParmDecl's may be
5902// required for future work, but a direct 1-to-1 copy doesn't seem
5903// possible as the logic is rather scattered throughout Clang's
5904// lowering and perhaps we wish to deviate slightly.
5905//
5906// \param mapData - A container containing vectors of information
5907// corresponding to the input argument, which should have a
5908// corresponding entry in the MapInfoData containers
5909// OrigialValue's.
5910// \param arg - This is the generated kernel function argument that
5911// corresponds to the passed in input argument. We generated different
5912// accesses of this Argument, based on capture type and other Input
5913// related information.
5914// \param input - This is the host side value that will be passed to
5915// the kernel i.e. the kernel input, we rewrite all uses of this within
5916// the kernel (as we generate the kernel body based on the target's region
5917// which maintians references to the original input) to the retVal argument
5918// apon exit of this function inside of the OMPIRBuilder. This interlinks
5919// the kernel argument to future uses of it in the function providing
5920// appropriate "glue" instructions inbetween.
5921// \param retVal - This is the value that all uses of input inside of the
5922// kernel will be re-written to, the goal of this function is to generate
5923// an appropriate location for the kernel argument to be accessed from,
5924// e.g. ByRef will result in a temporary allocation location and then
5925// a store of the kernel argument into this allocated memory which
5926// will then be loaded from, ByCopy will use the allocated memory
5927// directly.
5928static llvm::IRBuilderBase::InsertPoint
5929createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
5930 llvm::Value *input, llvm::Value *&retVal,
5931 llvm::IRBuilderBase &builder,
5932 llvm::OpenMPIRBuilder &ompBuilder,
5933 LLVM::ModuleTranslation &moduleTranslation,
5934 llvm::IRBuilderBase::InsertPoint allocaIP,
5935 llvm::IRBuilderBase::InsertPoint codeGenIP) {
5936 assert(ompBuilder.Config.isTargetDevice() &&
5937 "function only supported for target device codegen");
5938 builder.restoreIP(allocaIP);
5939
5940 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
5941 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
5942 ompBuilder.M.getContext());
5943 unsigned alignmentValue = 0;
5944 // Find the associated MapInfoData entry for the current input
5945 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
5946 if (mapData.OriginalValue[i] == input) {
5947 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5948 capture = mapOp.getMapCaptureType();
5949 // Get information of alignment of mapped object
5950 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
5951 mapOp.getVarType(), ompBuilder.M.getDataLayout());
5952 break;
5953 }
5954
5955 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
5956 unsigned int defaultAS =
5957 ompBuilder.M.getDataLayout().getProgramAddressSpace();
5958
5959 // Create the alloca for the argument the current point.
5960 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
5961
5962 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
5963 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
5964
5965 builder.CreateStore(&arg, v);
5966
5967 builder.restoreIP(codeGenIP);
5968
5969 switch (capture) {
5970 case omp::VariableCaptureKind::ByCopy: {
5971 retVal = v;
5972 break;
5973 }
5974 case omp::VariableCaptureKind::ByRef: {
5975 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
5976 v->getType(), v,
5977 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
5978 // CreateAlignedLoad function creates similar LLVM IR:
5979 // %res = load ptr, ptr %input, align 8
5980 // This LLVM IR does not contain information about alignment
5981 // of the loaded value. We need to add !align metadata to unblock
5982 // optimizer. The existence of the !align metadata on the instruction
5983 // tells the optimizer that the value loaded is known to be aligned to
5984 // a boundary specified by the integer value in the metadata node.
5985 // Example:
5986 // %res = load ptr, ptr %input, align 8, !align !align_md_node
5987 // ^ ^
5988 // | |
5989 // alignment of %input address |
5990 // |
5991 // alignment of %res object
5992 if (v->getType()->isPointerTy() && alignmentValue) {
5993 llvm::MDBuilder MDB(builder.getContext());
5994 loadInst->setMetadata(
5995 llvm::LLVMContext::MD_align,
5996 llvm::MDNode::get(builder.getContext(),
5997 MDB.createConstant(llvm::ConstantInt::get(
5998 llvm::Type::getInt64Ty(builder.getContext()),
5999 alignmentValue))));
6000 }
6001 retVal = loadInst;
6002
6003 break;
6004 }
6005 case omp::VariableCaptureKind::This:
6006 case omp::VariableCaptureKind::VLAType:
6007 // TODO: Consider returning error to use standard reporting for
6008 // unimplemented features.
6009 assert(false && "Currently unsupported capture kind");
6010 break;
6011 }
6012
6013 return builder.saveIP();
6014}
6015
/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                       Value &numTeamsLower, Value &numTeamsUpper,
                       Value &threadLimit,
                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
  // Walk each host_eval operand paired with its entry-block argument.
  for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                   blockArgIface.getHostEvalBlockArgs())) {
    Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);

    // Classify every use of the block argument by the nested operation and
    // clause it feeds; any use outside this known set is a verifier bug.
    for (Operation *user : blockArg.getUsers()) {
          .Case([&](omp::TeamsOp teamsOp) {
            // host_eval may drive team-count bounds and the thread limit.
            if (teamsOp.getNumTeamsLower() == blockArg)
              numTeamsLower = hostEvalVar;
            else if (teamsOp.getNumTeamsUpper() == blockArg)
              numTeamsUpper = hostEvalVar;
            else if (teamsOp.getThreadLimit() == blockArg)
              threadLimit = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::ParallelOp parallelOp) {
            // For parallel, only num_threads is a valid host_eval target.
            if (parallelOp.getNumThreads() == blockArg)
              numThreads = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::LoopNestOp loopOp) {
            // Record which loop-bound/step slot(s) the value occupies, but
            // only when the caller supplied an output vector for that kind.
            auto processBounds =
                [&](OperandRange opBounds,
                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
              bool found = false;
              for (auto [i, lb] : llvm::enumerate(opBounds)) {
                if (lb == blockArg) {
                  found = true;
                  if (outBounds)
                    (*outBounds)[i] = hostEvalVar;
                }
              }
              return found;
            };
            bool found =
                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
                    found;
            found = processBounds(loopOp.getLoopSteps(), steps) || found;
            (void)found;
            assert(found && "unsupported host_eval use");
          })
          .DefaultUnreachable("unsupported host_eval use");
    }
  }
}
6079
6080/// If \p op is of the given type parameter, return it casted to that type.
6081/// Otherwise, if its immediate parent operation (or some other higher-level
6082/// parent, if \p immediateParent is false) is of that type, return that parent
6083/// casted to the given type.
6084///
6085/// If \p op is \c null or neither it or its parent(s) are of the specified
6086/// type, return a \c null operation.
6087template <typename OpTy>
6088static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6089 if (!op)
6090 return OpTy();
6091
6092 if (OpTy casted = dyn_cast<OpTy>(op))
6093 return casted;
6094
6095 if (immediateParent)
6096 return dyn_cast_if_present<OpTy>(op->getParentOp());
6097
6098 return op->getParentOfType<OpTy>();
6099}
6100
6101/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6102/// it is of an integer type, return its value.
6103static std::optional<int64_t> extractConstInteger(Value value) {
6104 if (!value)
6105 return std::nullopt;
6106
6107 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6108 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6109 return constAttr.getInt();
6110
6111 return std::nullopt;
6112}
6113
6114static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6115 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6116 uint64_t sizeInBytes = sizeInBits / 8;
6117 return sizeInBytes;
6118}
6119
/// Compute the byte size of a non-packed struct holding one member per
/// reduction variable of \p op, using the enclosing module's data layout.
/// Returns 0 when the operation carries no reduction clauses.
template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
  if (op.getNumReductionVars() > 0) {
    collectReductionDecls(op, reductions);

    // Gather each reduction's element type into a literal struct and measure
    // that struct's size under the module's data layout.
    members.reserve(reductions.size());
    for (omp::DeclareReductionOp &red : reductions)
      members.push_back(red.getType());
    Operation *opp = op.getOperation();
    auto structType = mlir::LLVM::LLVMStructType::getLiteral(
        opp->getContext(), members, /*isPacked=*/false);
    DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
    return getTypeByteSize(structType, dl);
  }
  return 0;
}
6138
/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
static void
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice, bool isGPU) {
  // TODO: Handle constant 'if' clauses.

  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    // On the host, clause values arrive through host_eval operands.
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
      numThreads = parallelOp.getNumThreads();
  }

  // Handle clauses impacting the number of teams.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
    // match clang and set min and max to the same value.
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      // num_teams present but not constant: set but unknown.
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
                                                 /*immediateParent=*/true)) {
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Fold a clause value into `result`: constant values are taken verbatim;
  // a present-but-non-constant clause is marked as 0 ("set but unknown").
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // GPU teams reductions need to know the aggregate reduction payload size.
  int32_t reductionDataSize = 0;
  if (isGPU && capturedOp) {
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
      reductionDataSize = getReductionDataSize(teamsOp);
  }

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
  assert(
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                               omp::TargetRegionFlags::spmd) &&
      "invalid kernel flags");
  // Map the dialect's kernel flags onto the runtime execution mode:
  // generic+spmd -> GENERIC_SPMD, generic -> GENERIC, otherwise SPMD.
  attrs.ExecFlags =
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
  // A pure SPMD kernel additionally flagged no_loop can use the specialized
  // no-loop execution mode.
  if (omp::bitEnumContainsAll(kernelFlags,
                              omp::TargetRegionFlags::spmd |
                                  omp::TargetRegionFlags::no_loop) &&
      !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
    attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;

  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
  attrs.ReductionDataSize = reductionDataSize;
  // TODO: Allow modified buffer length similar to
  // fopenmp-cuda-teams-reduction-recs-num flag in clang.
  if (attrs.ReductionDataSize != 0)
    attrs.ReductionBufferLength = 1024;
}
6264
6265/// Gather LLVM runtime values for all clauses evaluated in the host that are
6266/// passed to the kernel invocation.
6267///
6268/// This function must be called only when compiling for the host. Also, it will
6269/// only provide correct results if it's called after the body of \c targetOp
6270/// has been fully generated.
6271static void
6272initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
6273 LLVM::ModuleTranslation &moduleTranslation,
6274 omp::TargetOp targetOp, Operation *capturedOp,
6275 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6276 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
6277 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6278
6279 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6280 llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
6281 steps(numLoops);
6282 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6283 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6284
6285 // TODO: Handle constant 'if' clauses.
6286 if (Value targetThreadLimit = targetOp.getThreadLimit())
6287 attrs.TargetThreadLimit.front() =
6288 moduleTranslation.lookupValue(targetThreadLimit);
6289
6290 if (numTeamsLower)
6291 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
6292
6293 if (numTeamsUpper)
6294 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
6295
6296 if (teamsThreadLimit)
6297 attrs.TeamsThreadLimit.front() =
6298 moduleTranslation.lookupValue(teamsThreadLimit);
6299
6300 if (numThreads)
6301 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
6302
6303 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
6304 omp::TargetRegionFlags::trip_count)) {
6305 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6306 attrs.LoopTripCount = nullptr;
6307
6308 // To calculate the trip count, we multiply together the trip counts of
6309 // every collapsed canonical loop. We don't need to create the loop nests
6310 // here, since we're only interested in the trip count.
6311 for (auto [loopLower, loopUpper, loopStep] :
6312 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
6313 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
6314 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
6315 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
6316
6317 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
6318 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
6319 loc, lowerBound, upperBound, step, /*IsSigned=*/true,
6320 loopOp.getLoopInclusive());
6321
6322 if (!attrs.LoopTripCount) {
6323 attrs.LoopTripCount = tripCount;
6324 continue;
6325 }
6326
6327 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
6328 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
6329 {}, /*HasNUW=*/true);
6330 }
6331 }
6332
6333 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6334 if (mlir::Value devId = targetOp.getDevice()) {
6335 attrs.DeviceID = moduleTranslation.lookupValue(devId);
6336 attrs.DeviceID =
6337 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
6338 }
6339}
6340
6341static LogicalResult
6342convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
6343 LLVM::ModuleTranslation &moduleTranslation) {
6344 auto targetOp = cast<omp::TargetOp>(opInst);
6345 // The current debug location already has the DISubprogram for the outlined
6346 // function that will be created for the target op. We save it here so that
6347 // we can set it on the outlined function.
6348 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
6349 if (failed(checkImplementationStatus(opInst)))
6350 return failure();
6351
6352 // During the handling of target op, we will generate instructions in the
6353 // parent function like call to the oulined function or branch to a new
6354 // BasicBlock. We set the debug location here to parent function so that those
6355 // get the correct debug locations. For outlined functions, the normal MLIR op
6356 // conversion will automatically pick the correct location.
6357 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
6358 assert(parentBB && "No insert block is set for the builder");
6359 llvm::Function *parentLLVMFn = parentBB->getParent();
6360 assert(parentLLVMFn && "Parent Function must be valid");
6361 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
6362 builder.SetCurrentDebugLocation(llvm::DILocation::get(
6363 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
6364 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
6365
6366 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6367 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6368 bool isGPU = ompBuilder->Config.isGPU();
6369
6370 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
6371 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
6372 auto &targetRegion = targetOp.getRegion();
6373 // Holds the private vars that have been mapped along with the block
6374 // argument that corresponds to the MapInfoOp corresponding to the private
6375 // var in question. So, for instance:
6376 //
6377 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
6378 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
6379 //
6380 // Then, %10 has been created so that the descriptor can be used by the
6381 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
6382 // %arg0} in the mappedPrivateVars map.
6383 llvm::DenseMap<Value, Value> mappedPrivateVars;
6384 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
6385 SmallVector<Value> mapVars = targetOp.getMapVars();
6386 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
6387 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
6388 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
6389 llvm::Function *llvmOutlinedFn = nullptr;
6390 TargetDirectiveEnumTy targetDirective =
6391 getTargetDirectiveEnumTyFromOp(&opInst);
6392
6393 // TODO: It can also be false if a compile-time constant `false` IF clause is
6394 // specified.
6395 bool isOffloadEntry =
6396 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
6397
6398 // For some private variables, the MapsForPrivatizedVariablesPass
6399 // creates MapInfoOp instances. Go through the private variables and
6400 // the mapped variables so that during codegeneration we are able
6401 // to quickly look up the corresponding map variable, if any for each
6402 // private variable.
6403 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
6404 OperandRange privateVars = targetOp.getPrivateVars();
6405 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
6406 std::optional<DenseI64ArrayAttr> privateMapIndices =
6407 targetOp.getPrivateMapsAttr();
6408
6409 for (auto [privVarIdx, privVarSymPair] :
6410 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
6411 auto privVar = std::get<0>(privVarSymPair);
6412 auto privSym = std::get<1>(privVarSymPair);
6413
6414 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
6415 omp::PrivateClauseOp privatizer =
6416 findPrivatizer(targetOp, privatizerName);
6417
6418 if (!privatizer.needsMap())
6419 continue;
6420
6421 mlir::Value mappedValue =
6422 targetOp.getMappedValueForPrivateVar(privVarIdx);
6423 assert(mappedValue && "Expected to find mapped value for a privatized "
6424 "variable that needs mapping");
6425
6426 // The MapInfoOp defining the map var isn't really needed later.
6427 // So, we don't store it in any datastructure. Instead, we just
6428 // do some sanity checks on it right now.
6429 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
6430 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
6431
6432 // Check #1: Check that the type of the private variable matches
6433 // the type of the variable being mapped.
6434 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
6435 assert(
6436 varType == privVar.getType() &&
6437 "Type of private var doesn't match the type of the mapped value");
6438
6439 // Ok, only 1 sanity check for now.
6440 // Record the block argument corresponding to this mapvar.
6441 mappedPrivateVars.insert(
6442 {privVar,
6443 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
6444 (*privateMapIndices)[privVarIdx])});
6445 }
6446 }
6447
6448 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6449 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
6450 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6451 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6452 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6453 // Forward target-cpu and target-features function attributes from the
6454 // original function to the new outlined function.
6455 llvm::Function *llvmParentFn =
6456 moduleTranslation.lookupFunction(parentFn.getName());
6457 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
6458 assert(llvmParentFn && llvmOutlinedFn &&
6459 "Both parent and outlined functions must exist at this point");
6460
6461 if (outlinedFnLoc && llvmParentFn->getSubprogram())
6462 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
6463
6464 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
6465 attr.isStringAttribute())
6466 llvmOutlinedFn->addFnAttr(attr);
6467
6468 if (auto attr = llvmParentFn->getFnAttribute("target-features");
6469 attr.isStringAttribute())
6470 llvmOutlinedFn->addFnAttr(attr);
6471
6472 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
6473 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6474 llvm::Value *mapOpValue =
6475 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6476 moduleTranslation.mapValue(arg, mapOpValue);
6477 }
6478 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
6479 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
6480 llvm::Value *mapOpValue =
6481 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
6482 moduleTranslation.mapValue(arg, mapOpValue);
6483 }
6484
6485 // Do privatization after moduleTranslation has already recorded
6486 // mapped values.
6487 PrivateVarsInfo privateVarsInfo(targetOp);
6488
6490 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
6491 allocaIP, &mappedPrivateVars);
6492
6493 if (failed(handleError(afterAllocas, *targetOp)))
6494 return llvm::make_error<PreviouslyReportedError>();
6495
6496 builder.restoreIP(codeGenIP);
6497 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
6498 &mappedPrivateVars),
6499 *targetOp)
6500 .failed())
6501 return llvm::make_error<PreviouslyReportedError>();
6502
6503 if (failed(copyFirstPrivateVars(
6504 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
6505 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
6506 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
6507 return llvm::make_error<PreviouslyReportedError>();
6508
6509 SmallVector<Region *> privateCleanupRegions;
6510 llvm::transform(privateVarsInfo.privatizers,
6511 std::back_inserter(privateCleanupRegions),
6512 [](omp::PrivateClauseOp privatizer) {
6513 return &privatizer.getDeallocRegion();
6514 });
6515
6517 targetRegion, "omp.target", builder, moduleTranslation);
6518
6519 if (!exitBlock)
6520 return exitBlock.takeError();
6521
6522 builder.SetInsertPoint(*exitBlock);
6523 if (!privateCleanupRegions.empty()) {
6524 if (failed(inlineOmpRegionCleanup(
6525 privateCleanupRegions, privateVarsInfo.llvmVars,
6526 moduleTranslation, builder, "omp.targetop.private.cleanup",
6527 /*shouldLoadCleanupRegionArg=*/false))) {
6528 return llvm::createStringError(
6529 "failed to inline `dealloc` region of `omp.private` "
6530 "op in the target region");
6531 }
6532 return builder.saveIP();
6533 }
6534
6535 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
6536 };
6537
6538 StringRef parentName = parentFn.getName();
6539
6540 llvm::TargetRegionEntryInfo entryInfo;
6541
6542 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
6543
6544 MapInfoData mapData;
6545 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6546 builder, /*useDevPtrOperands=*/{},
6547 /*useDevAddrOperands=*/{}, hdaVars);
6548
6549 MapInfosTy combinedInfos;
6550 auto genMapInfoCB =
6551 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
6552 builder.restoreIP(codeGenIP);
6553 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
6554 targetDirective);
6555 return combinedInfos;
6556 };
6557
6558 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
6559 llvm::Value *&retVal, InsertPointTy allocaIP,
6560 InsertPointTy codeGenIP)
6561 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6562 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6563 builder.SetCurrentDebugLocation(llvm::DebugLoc());
6564 // We just return the unaltered argument for the host function
6565 // for now, some alterations may be required in the future to
6566 // keep host fallback functions working identically to the device
6567 // version (e.g. pass ByCopy values should be treated as such on
6568 // host and device, currently not always the case)
6569 if (!isTargetDevice) {
6570 retVal = cast<llvm::Value>(&arg);
6571 return codeGenIP;
6572 }
6573
6574 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
6575 *ompBuilder, moduleTranslation,
6576 allocaIP, codeGenIP);
6577 };
6578
6579 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
6580 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
6581 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
6582 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
6583 isTargetDevice, isGPU);
6584
6585 // Collect host-evaluated values needed to properly launch the kernel from the
6586 // host.
6587 if (!isTargetDevice)
6588 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
6589 targetCapturedOp, runtimeAttrs);
6590
6591 // Pass host-evaluated values as parameters to the kernel / host fallback,
6592 // except if they are constants. In any case, map the MLIR block argument to
6593 // the corresponding LLVM values.
6595 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
6596 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
6597 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
6598 llvm::Value *value = moduleTranslation.lookupValue(var);
6599 moduleTranslation.mapValue(arg, value);
6600
6601 if (!llvm::isa<llvm::Constant>(value))
6602 kernelInput.push_back(value);
6603 }
6604
6605 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
6606 // declare target arguments are not passed to kernels as arguments
6607 // TODO: We currently do not handle cases where a member is explicitly
6608 // passed in as an argument, this will likley need to be handled in
6609 // the near future, rather than using IsAMember, it may be better to
6610 // test if the relevant BlockArg is used within the target region and
6611 // then use that as a basis for exclusion in the kernel inputs.
6612 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
6613 kernelInput.push_back(mapData.OriginalValue[i]);
6614 }
6615
6617 buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
6618 moduleTranslation, dds);
6619
6620 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6621 findAllocaInsertPoint(builder, moduleTranslation);
6622 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6623
6624 llvm::OpenMPIRBuilder::TargetDataInfo info(
6625 /*RequiresDevicePointerInfo=*/false,
6626 /*SeparateBeginEndCalls=*/true);
6627
6628 auto customMapperCB =
6629 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6630 if (!combinedInfos.Mappers[i])
6631 return nullptr;
6632 info.HasMapper = true;
6633 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
6634 moduleTranslation, targetDirective);
6635 };
6636
6637 llvm::Value *ifCond = nullptr;
6638 if (Value targetIfCond = targetOp.getIfExpr())
6639 ifCond = moduleTranslation.lookupValue(targetIfCond);
6640
6641 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6642 moduleTranslation.getOpenMPBuilder()->createTarget(
6643 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
6644 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
6645 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
6646
6647 if (failed(handleError(afterIP, opInst)))
6648 return failure();
6649
6650 builder.restoreIP(*afterIP);
6651
6652 // Remap access operations to declare target reference pointers for the
6653 // device, essentially generating extra loadop's as necessary
6654 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
6655 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
6656 llvmOutlinedFn);
6657
6658 return success();
6659}
6660
/// Handle the `omp.declare_target` attribute when attached to functions and
/// globals during translation to LLVM IR.
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      // Host compilation: nothing to delete.
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // Device compilation: erase host-only wrapper functions from the
      // translated LLVM module.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  // Globals marked declare target are registered with the OpenMPIRBuilder so
  // offload entry metadata is generated for them.
  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      // May be null when the location is not a FileLineColLoc; the callback
      // below then falls back to an empty filename and line 0.
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      // Pick up the module's target triple, if one was set.
      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Supplies the (filename, line) pair used to compute a unique offload
      // entry identifier for this global.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      auto vfs = llvm::vfs::getRealFileSystem();

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
          mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      // NOTE(review): on the device, non-'to' captures (e.g. 'link'), and
      // 'to' captures under requires unified_shared_memory, appear to need a
      // declare-target reference pointer as well -- confirm against the
      // OpenMPIRBuilder implementation.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
            mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
            gVal->getType(), /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
6752
6753// Returns true if the operation is inside a TargetOp or
6754// is part of a declare target function.
6755static bool isTargetDeviceOp(Operation *op) {
6756 // Assumes no reverse offloading
6757 if (op->getParentOfType<omp::TargetOp>())
6758 return true;
6759
6760 // Certain operations return results, and whether utilised in host or
6761 // target there is a chance an LLVM Dialect operation depends on it
6762 // by taking it in as an operand, so we must always lower these in
6763 // some manner or result in an ICE (whether they end up in a no-op
6764 // or otherwise).
6765 if (mlir::isa<omp::ThreadprivateOp>(op))
6766 return true;
6767
6768 if (mlir::isa<omp::TargetAllocMemOp>(op) ||
6769 mlir::isa<omp::TargetFreeMemOp>(op))
6770 return true;
6771
6772 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
6773 if (auto declareTargetIface =
6774 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6775 parentFn.getOperation()))
6776 if (declareTargetIface.isDeclareTarget() &&
6777 declareTargetIface.getDeclareTargetDeviceType() !=
6778 mlir::omp::DeclareTargetDeviceType::host)
6779 return true;
6780
6781 return false;
6782}
6783
6784static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
6785 llvm::Module *llvmModule) {
6786 llvm::Type *i64Ty = builder.getInt64Ty();
6787 llvm::Type *i32Ty = builder.getInt32Ty();
6788 llvm::Type *returnType = builder.getPtrTy(0);
6789 llvm::FunctionType *fnType =
6790 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
6791 llvm::Function *func = cast<llvm::Function>(
6792 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
6793 return func;
6794}
6795
6796static LogicalResult
6797convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6798 LLVM::ModuleTranslation &moduleTranslation) {
6799 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6800 if (!allocMemOp)
6801 return failure();
6802
6803 // Get "omp_target_alloc" function
6804 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6805 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
6806 // Get the corresponding device value in llvm
6807 mlir::Value deviceNum = allocMemOp.getDevice();
6808 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6809 // Get the allocation size.
6810 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6811 mlir::Type heapTy = allocMemOp.getAllocatedType();
6812 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6813 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6814 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6815 for (auto typeParam : allocMemOp.getTypeparams())
6816 allocSize =
6817 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6818 // Create call to "omp_target_alloc" with the args as translated llvm values.
6819 llvm::CallInst *call =
6820 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6821 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6822
6823 // Map the result
6824 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
6825 return success();
6826}
6827
6828static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
6829 llvm::Module *llvmModule) {
6830 llvm::Type *ptrTy = builder.getPtrTy(0);
6831 llvm::Type *i32Ty = builder.getInt32Ty();
6832 llvm::Type *voidTy = builder.getVoidTy();
6833 llvm::FunctionType *fnType =
6834 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
6835 llvm::Function *func = dyn_cast<llvm::Function>(
6836 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
6837 return func;
6838}
6839
6840static LogicalResult
6841convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6842 LLVM::ModuleTranslation &moduleTranslation) {
6843 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6844 if (!freeMemOp)
6845 return failure();
6846
6847 // Get "omp_target_free" function
6848 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6849 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
6850 // Get the corresponding device value in llvm
6851 mlir::Value deviceNum = freeMemOp.getDevice();
6852 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6853 // Get the corresponding heapref value in llvm
6854 mlir::Value heapref = freeMemOp.getHeapref();
6855 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
6856 // Convert heapref int to ptr and call "omp_target_free"
6857 llvm::Value *intToPtr =
6858 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
6859 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
6860 return success();
6861}
6862
6863/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
6864/// OpenMP runtime calls).
6865static LogicalResult
6866convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
6867 LLVM::ModuleTranslation &moduleTranslation) {
6868 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6869
6870 // For each loop, introduce one stack frame to hold loop information. Ensure
6871 // this is only done for the outermost loop wrapper to prevent introducing
6872 // multiple stack frames for a single loop. Initially set to null, the loop
6873 // information structure is initialized during translation of the nested
6874 // omp.loop_nest operation, making it available to translation of all loop
6875 // wrappers after their body has been successfully translated.
6876 bool isOutermostLoopWrapper =
6877 isa_and_present<omp::LoopWrapperInterface>(op) &&
6878 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
6879
6880 if (isOutermostLoopWrapper)
6881 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
6882
6883 auto result =
6885 .Case([&](omp::BarrierOp op) -> LogicalResult {
6886 if (failed(checkImplementationStatus(*op)))
6887 return failure();
6888
6889 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6890 ompBuilder->createBarrier(builder.saveIP(),
6891 llvm::omp::OMPD_barrier);
6892 LogicalResult res = handleError(afterIP, *op);
6893 if (res.succeeded()) {
6894 // If the barrier generated a cancellation check, the insertion
6895 // point might now need to be changed to a new continuation block
6896 builder.restoreIP(*afterIP);
6897 }
6898 return res;
6899 })
6900 .Case([&](omp::TaskyieldOp op) {
6901 if (failed(checkImplementationStatus(*op)))
6902 return failure();
6903
6904 ompBuilder->createTaskyield(builder.saveIP());
6905 return success();
6906 })
6907 .Case([&](omp::FlushOp op) {
6908 if (failed(checkImplementationStatus(*op)))
6909 return failure();
6910
6911 // No support in Openmp runtime function (__kmpc_flush) to accept
6912 // the argument list.
6913 // OpenMP standard states the following:
6914 // "An implementation may implement a flush with a list by ignoring
6915 // the list, and treating it the same as a flush without a list."
6916 //
6917 // The argument list is discarded so that, flush with a list is
6918 // treated same as a flush without a list.
6919 ompBuilder->createFlush(builder.saveIP());
6920 return success();
6921 })
6922 .Case([&](omp::ParallelOp op) {
6923 return convertOmpParallel(op, builder, moduleTranslation);
6924 })
6925 .Case([&](omp::MaskedOp) {
6926 return convertOmpMasked(*op, builder, moduleTranslation);
6927 })
6928 .Case([&](omp::MasterOp) {
6929 return convertOmpMaster(*op, builder, moduleTranslation);
6930 })
6931 .Case([&](omp::CriticalOp) {
6932 return convertOmpCritical(*op, builder, moduleTranslation);
6933 })
6934 .Case([&](omp::OrderedRegionOp) {
6935 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
6936 })
6937 .Case([&](omp::OrderedOp) {
6938 return convertOmpOrdered(*op, builder, moduleTranslation);
6939 })
6940 .Case([&](omp::WsloopOp) {
6941 return convertOmpWsloop(*op, builder, moduleTranslation);
6942 })
6943 .Case([&](omp::SimdOp) {
6944 return convertOmpSimd(*op, builder, moduleTranslation);
6945 })
6946 .Case([&](omp::AtomicReadOp) {
6947 return convertOmpAtomicRead(*op, builder, moduleTranslation);
6948 })
6949 .Case([&](omp::AtomicWriteOp) {
6950 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
6951 })
6952 .Case([&](omp::AtomicUpdateOp op) {
6953 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
6954 })
6955 .Case([&](omp::AtomicCaptureOp op) {
6956 return convertOmpAtomicCapture(op, builder, moduleTranslation);
6957 })
6958 .Case([&](omp::CancelOp op) {
6959 return convertOmpCancel(op, builder, moduleTranslation);
6960 })
6961 .Case([&](omp::CancellationPointOp op) {
6962 return convertOmpCancellationPoint(op, builder, moduleTranslation);
6963 })
6964 .Case([&](omp::SectionsOp) {
6965 return convertOmpSections(*op, builder, moduleTranslation);
6966 })
6967 .Case([&](omp::SingleOp op) {
6968 return convertOmpSingle(op, builder, moduleTranslation);
6969 })
6970 .Case([&](omp::TeamsOp op) {
6971 return convertOmpTeams(op, builder, moduleTranslation);
6972 })
6973 .Case([&](omp::TaskOp op) {
6974 return convertOmpTaskOp(op, builder, moduleTranslation);
6975 })
6976 .Case([&](omp::TaskloopOp op) {
6977 return convertOmpTaskloopOp(*op, builder, moduleTranslation);
6978 })
6979 .Case([&](omp::TaskgroupOp op) {
6980 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
6981 })
6982 .Case([&](omp::TaskwaitOp op) {
6983 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
6984 })
6985 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
6986 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
6987 omp::CriticalDeclareOp>([](auto op) {
6988 // `yield` and `terminator` can be just omitted. The block structure
6989 // was created in the region that handles their parent operation.
6990 // `declare_reduction` will be used by reductions and is not
6991 // converted directly, skip it.
6992 // `declare_mapper` and `declare_mapper.info` are handled whenever
6993 // they are referred to through a `map` clause.
6994 // `critical.declare` is only used to declare names of critical
6995 // sections which will be used by `critical` ops and hence can be
6996 // ignored for lowering. The OpenMP IRBuilder will create unique
6997 // name for critical section names.
6998 return success();
6999 })
7000 .Case([&](omp::ThreadprivateOp) {
7001 return convertOmpThreadprivate(*op, builder, moduleTranslation);
7002 })
7003 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7004 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
7005 return convertOmpTargetData(op, builder, moduleTranslation);
7006 })
7007 .Case([&](omp::TargetOp) {
7008 return convertOmpTarget(*op, builder, moduleTranslation);
7009 })
7010 .Case([&](omp::DistributeOp) {
7011 return convertOmpDistribute(*op, builder, moduleTranslation);
7012 })
7013 .Case([&](omp::LoopNestOp) {
7014 return convertOmpLoopNest(*op, builder, moduleTranslation);
7015 })
7016 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7017 [&](auto op) {
7018 // No-op, should be handled by relevant owning operations e.g.
7019 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
7020 // etc. and then discarded
7021 return success();
7022 })
7023 .Case([&](omp::NewCliOp op) {
7024 // Meta-operation: Doesn't do anything by itself, but used to
7025 // identify a loop.
7026 return success();
7027 })
7028 .Case([&](omp::CanonicalLoopOp op) {
7029 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
7030 })
7031 .Case([&](omp::UnrollHeuristicOp op) {
7032 // FIXME: Handling omp.unroll_heuristic as an executable requires
7033 // that the generator (e.g. omp.canonical_loop) has been seen first.
7034 // For construct that require all codegen to occur inside a callback
7035 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
7036 // contained region including their transformations must occur at
7037 // the omp.canonical_loop.
7038 return applyUnrollHeuristic(op, builder, moduleTranslation);
7039 })
7040 .Case([&](omp::TileOp op) {
7041 return applyTile(op, builder, moduleTranslation);
7042 })
7043 .Case([&](omp::TargetAllocMemOp) {
7044 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
7045 })
7046 .Case([&](omp::TargetFreeMemOp) {
7047 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
7048 })
7049 .Default([&](Operation *inst) {
7050 return inst->emitError()
7051 << "not yet implemented: " << inst->getName();
7052 });
7053
7054 if (isOutermostLoopWrapper)
7055 moduleTranslation.stackPop();
7056
7057 return result;
7058}
7059
7060static LogicalResult
7061convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
7062 LLVM::ModuleTranslation &moduleTranslation) {
7063 return convertHostOrTargetOperation(op, builder, moduleTranslation);
7064}
7065
7066static LogicalResult
7067convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
7068 LLVM::ModuleTranslation &moduleTranslation) {
7069 if (isa<omp::TargetOp>(op))
7070 return convertOmpTarget(*op, builder, moduleTranslation);
7071 if (isa<omp::TargetDataOp>(op))
7072 return convertOmpTargetData(op, builder, moduleTranslation);
7073 bool interrupted =
7074 op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
7075 if (isa<omp::TargetOp>(oper)) {
7076 if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
7077 return WalkResult::interrupt();
7078 return WalkResult::skip();
7079 }
7080 if (isa<omp::TargetDataOp>(oper)) {
7081 if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
7082 return WalkResult::interrupt();
7083 return WalkResult::skip();
7084 }
7085
7086 // Non-target ops might nest target-related ops, therefore, we
7087 // translate them as non-OpenMP scopes. Translating them is needed by
7088 // nested target-related ops since they might need LLVM values defined
7089 // in their parent non-target ops.
7090 if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
7091 oper->getParentOfType<LLVM::LLVMFuncOp>() &&
7092 !oper->getRegions().empty()) {
7093 if (auto blockArgsIface =
7094 dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
7095 forwardArgs(moduleTranslation, blockArgsIface);
7096 else {
7097 // Here we map entry block arguments of
7098 // non-BlockArgOpenMPOpInterface ops if they can be encountered
7099 // inside of a function and they define any of these arguments.
7100 if (isa<mlir::omp::AtomicUpdateOp>(oper))
7101 for (auto [operand, arg] :
7102 llvm::zip_equal(oper->getOperands(),
7103 oper->getRegion(0).getArguments())) {
7104 moduleTranslation.mapValue(
7105 arg, builder.CreateLoad(
7106 moduleTranslation.convertType(arg.getType()),
7107 moduleTranslation.lookupValue(operand)));
7108 }
7109 }
7110
7111 if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
7112 assert(builder.GetInsertBlock() &&
7113 "No insert block is set for the builder");
7114 for (auto iv : loopNest.getIVs()) {
7115 // Map iv to an undefined value just to keep the IR validity.
7116 moduleTranslation.mapValue(
7117 iv, llvm::PoisonValue::get(
7118 moduleTranslation.convertType(iv.getType())));
7119 }
7120 }
7121
7122 for (Region &region : oper->getRegions()) {
7123 // Regions are fake in the sense that they are not a truthful
7124 // translation of the OpenMP construct being converted (e.g. no
7125 // OpenMP runtime calls will be generated). We just need this to
7126 // prepare the kernel invocation args.
7129 region, oper->getName().getStringRef().str() + ".fake.region",
7130 builder, moduleTranslation, &phis);
7131 if (failed(handleError(result, *oper)))
7132 return WalkResult::interrupt();
7133
7134 builder.SetInsertPoint(result.get(), result.get()->end());
7135 }
7136
7137 return WalkResult::skip();
7138 }
7139
7140 return WalkResult::advance();
7141 }).wasInterrupted();
7142 return failure(interrupted);
7143}
7144
7145namespace {
7146
7147/// Implementation of the dialect interface that converts operations belonging
7148/// to the OpenMP dialect to LLVM IR.
7149class OpenMPDialectLLVMIRTranslationInterface
7150 : public LLVMTranslationDialectInterface {
7151public:
7153
7154 /// Translates the given operation to LLVM IR using the provided IR builder
7155 /// and saving the state in `moduleTranslation`.
7156 LogicalResult
7157 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7158 LLVM::ModuleTranslation &moduleTranslation) const final;
7159
7160 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
7161 /// runtime calls, or operation amendments
7162 LogicalResult
7163 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7164 NamedAttribute attribute,
7165 LLVM::ModuleTranslation &moduleTranslation) const final;
7166};
7167
7168} // namespace
7169
7170LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7171 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7172 NamedAttribute attribute,
7173 LLVM::ModuleTranslation &moduleTranslation) const {
7174 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7175 attribute.getName())
7176 .Case("omp.is_target_device",
7177 [&](Attribute attr) {
7178 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7179 llvm::OpenMPIRBuilderConfig &config =
7180 moduleTranslation.getOpenMPBuilder()->Config;
7181 config.setIsTargetDevice(deviceAttr.getValue());
7182 return success();
7183 }
7184 return failure();
7185 })
7186 .Case("omp.is_gpu",
7187 [&](Attribute attr) {
7188 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7189 llvm::OpenMPIRBuilderConfig &config =
7190 moduleTranslation.getOpenMPBuilder()->Config;
7191 config.setIsGPU(gpuAttr.getValue());
7192 return success();
7193 }
7194 return failure();
7195 })
7196 .Case("omp.host_ir_filepath",
7197 [&](Attribute attr) {
7198 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7199 llvm::OpenMPIRBuilder *ompBuilder =
7200 moduleTranslation.getOpenMPBuilder();
7201 auto VFS = llvm::vfs::getRealFileSystem();
7202 ompBuilder->loadOffloadInfoMetadata(*VFS,
7203 filepathAttr.getValue());
7204 return success();
7205 }
7206 return failure();
7207 })
7208 .Case("omp.flags",
7209 [&](Attribute attr) {
7210 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7211 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
7212 return failure();
7213 })
7214 .Case("omp.version",
7215 [&](Attribute attr) {
7216 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7217 llvm::OpenMPIRBuilder *ompBuilder =
7218 moduleTranslation.getOpenMPBuilder();
7219 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
7220 versionAttr.getVersion());
7221 return success();
7222 }
7223 return failure();
7224 })
7225 .Case("omp.declare_target",
7226 [&](Attribute attr) {
7227 if (auto declareTargetAttr =
7228 dyn_cast<omp::DeclareTargetAttr>(attr))
7229 return convertDeclareTargetAttr(op, declareTargetAttr,
7230 moduleTranslation);
7231 return failure();
7232 })
7233 .Case("omp.requires",
7234 [&](Attribute attr) {
7235 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7236 using Requires = omp::ClauseRequires;
7237 Requires flags = requiresAttr.getValue();
7238 llvm::OpenMPIRBuilderConfig &config =
7239 moduleTranslation.getOpenMPBuilder()->Config;
7240 config.setHasRequiresReverseOffload(
7241 bitEnumContainsAll(flags, Requires::reverse_offload));
7242 config.setHasRequiresUnifiedAddress(
7243 bitEnumContainsAll(flags, Requires::unified_address));
7244 config.setHasRequiresUnifiedSharedMemory(
7245 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7246 config.setHasRequiresDynamicAllocators(
7247 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7248 return success();
7249 }
7250 return failure();
7251 })
7252 .Case("omp.target_triples",
7253 [&](Attribute attr) {
7254 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7255 llvm::OpenMPIRBuilderConfig &config =
7256 moduleTranslation.getOpenMPBuilder()->Config;
7257 config.TargetTriples.clear();
7258 config.TargetTriples.reserve(triplesAttr.size());
7259 for (Attribute tripleAttr : triplesAttr) {
7260 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7261 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7262 else
7263 return failure();
7264 }
7265 return success();
7266 }
7267 return failure();
7268 })
7269 .Default([](Attribute) {
7270 // Fall through for omp attributes that do not require lowering.
7271 return success();
7272 })(attribute.getValue());
7273
7274 return failure();
7275}
7276
7277/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
7278/// (including OpenMP runtime calls).
7279LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7280 Operation *op, llvm::IRBuilderBase &builder,
7281 LLVM::ModuleTranslation &moduleTranslation) const {
7282
7283 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7284 if (ompBuilder->Config.isTargetDevice()) {
7285 if (isTargetDeviceOp(op)) {
7286 return convertTargetDeviceOp(op, builder, moduleTranslation);
7287 }
7288 return convertTargetOpsInNest(op, builder, moduleTranslation);
7289 }
7290 return convertHostOrTargetOperation(op, builder, moduleTranslation);
7291}
7292
7294 registry.insert<omp::OpenMPDialect>();
7295 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7296 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7297 });
7298}
7299
7301 DialectRegistry registry;
7303 context.appendDialectRegistry(registry);
7304}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static bool isTargetDeviceOp(Operation *op)
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation, omp::BlockArgOpenMPOpInterface blockArgIface)
Maps block arguments from blockArgIface (which are MLIR values) to the corresponding LLVM values of t...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static LogicalResult convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Given an OpenMP MLIR operation, create the corresponding LLVM IR (including OpenMP runtime calls).
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1293
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars