MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
20#include "mlir/Support/LLVM.h"
23
24#include "llvm/ADT/ArrayRef.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Frontend/OpenMP/OMPConstants.h"
28#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DebugInfoMetadata.h"
31#include "llvm/IR/DerivedTypes.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/MDBuilder.h"
34#include "llvm/IR/ReplaceConstant.h"
35#include "llvm/Support/FileSystem.h"
36#include "llvm/Support/VirtualFileSystem.h"
37#include "llvm/TargetParser/Triple.h"
38#include "llvm/Transforms/Utils/ModuleUtils.h"
39
40#include <cstdint>
41#include <iterator>
42#include <numeric>
43#include <optional>
44#include <utility>
45
46using namespace mlir;
47
48namespace {
49static llvm::omp::ScheduleKind
50convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
51 if (!schedKind.has_value())
52 return llvm::omp::OMP_SCHEDULE_Default;
53 switch (schedKind.value()) {
54 case omp::ClauseScheduleKind::Static:
55 return llvm::omp::OMP_SCHEDULE_Static;
56 case omp::ClauseScheduleKind::Dynamic:
57 return llvm::omp::OMP_SCHEDULE_Dynamic;
58 case omp::ClauseScheduleKind::Guided:
59 return llvm::omp::OMP_SCHEDULE_Guided;
60 case omp::ClauseScheduleKind::Auto:
61 return llvm::omp::OMP_SCHEDULE_Auto;
62 case omp::ClauseScheduleKind::Runtime:
63 return llvm::omp::OMP_SCHEDULE_Runtime;
64 case omp::ClauseScheduleKind::Distribute:
65 return llvm::omp::OMP_SCHEDULE_Distribute;
66 }
67 llvm_unreachable("unhandled schedule clause argument");
68}
69
70/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
71/// insertion points for allocas.
72class OpenMPAllocaStackFrame
73 : public StateStackFrameBase<OpenMPAllocaStackFrame> {
74public:
76
77 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
78 : allocaInsertPoint(allocaIP) {}
79 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
80};
81
82/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
83/// collapsed canonical loop information corresponding to an \c omp.loop_nest
84/// operation.
85class OpenMPLoopInfoStackFrame
86 : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
87public:
88 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
89 llvm::CanonicalLoopInfo *loopInfo = nullptr;
90};
91
92/// Custom error class to signal translation errors that don't need reporting,
93/// since encountering them will have already triggered relevant error messages.
94///
95/// Its purpose is to serve as the glue between MLIR failures represented as
96/// \see LogicalResult instances and \see llvm::Error instances used to
97/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
98/// error of the first type is raised, a message is emitted directly (the \see
99/// LogicalResult itself does not hold any information). If we need to forward
100/// this error condition as an \see llvm::Error while avoiding triggering some
101/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
102/// class to just signal this situation has happened.
103///
104/// For example, this class should be used to trigger errors from within
105/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
106/// translation of their own regions. This unclutters the error log from
107/// redundant messages.
108class PreviouslyReportedError
109 : public llvm::ErrorInfo<PreviouslyReportedError> {
110public:
111 void log(raw_ostream &) const override {
112 // Do not log anything.
113 }
114
115 std::error_code convertToErrorCode() const override {
116 llvm_unreachable(
117 "PreviouslyReportedError doesn't support ECError conversion");
118 }
119
120 // Used by ErrorInfo::classID.
121 static char ID;
122};
123
124char PreviouslyReportedError::ID = 0;
125
126/*
127 * Custom class for processing linear clause for omp.wsloop
128 * and omp.simd. Linear clause translation requires setup,
129 * initialization, update, and finalization at varying
130 * basic blocks in the IR. This class helps maintain
131 * internal state to allow consistent translation in
132 * each of these stages.
133 */
134
135class LinearClauseProcessor {
136
137private:
138 SmallVector<llvm::Value *> linearPreconditionVars;
139 SmallVector<llvm::Value *> linearLoopBodyTemps;
140 SmallVector<llvm::Value *> linearOrigVal;
141 SmallVector<llvm::Value *> linearSteps;
142 SmallVector<llvm::Type *> linearVarTypes;
143 llvm::BasicBlock *linearFinalizationBB;
144 llvm::BasicBlock *linearExitBB;
145 llvm::BasicBlock *linearLastIterExitBB;
146
147public:
148 // Register type for the linear variables
149 void registerType(LLVM::ModuleTranslation &moduleTranslation,
150 mlir::Attribute &ty) {
151 linearVarTypes.push_back(moduleTranslation.convertType(
152 mlir::cast<mlir::TypeAttr>(ty).getValue()));
153 }
154
155 // Allocate space for linear variabes
156 void createLinearVar(llvm::IRBuilderBase &builder,
157 LLVM::ModuleTranslation &moduleTranslation,
158 llvm::Value *linearVar, int idx) {
159 linearPreconditionVars.push_back(
160 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
161 llvm::Value *linearLoopBodyTemp =
162 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
163 linearOrigVal.push_back(linearVar);
164 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
165 }
166
167 // Initialize linear step
168 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
169 mlir::Value &linearStep) {
170 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
171 }
172
173 // Emit IR for initialization of linear variables
174 void initLinearVar(llvm::IRBuilderBase &builder,
175 LLVM::ModuleTranslation &moduleTranslation,
176 llvm::BasicBlock *loopPreHeader) {
177 builder.SetInsertPoint(loopPreHeader->getTerminator());
178 for (size_t index = 0; index < linearOrigVal.size(); index++) {
179 llvm::LoadInst *linearVarLoad =
180 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
181 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
182 }
183 }
184
185 // Emit IR for updating Linear variables
186 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
187 llvm::Value *loopInductionVar) {
188 builder.SetInsertPoint(loopBody->getTerminator());
189 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
190 llvm::Type *linearVarType = linearVarTypes[index];
191 llvm::Value *iv = loopInductionVar;
192 llvm::Value *step = linearSteps[index];
193
194 if (!iv->getType()->isIntegerTy())
195 llvm_unreachable("OpenMP loop induction variable must be an integer "
196 "type");
197
198 if (linearVarType->isIntegerTy()) {
199 // Integer path: normalize all arithmetic to linearVarType
200 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
201 step = builder.CreateSExtOrTrunc(step, linearVarType);
202
203 llvm::LoadInst *linearVarStart =
204 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
205 llvm::Value *mulInst = builder.CreateMul(iv, step);
206 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
207 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
208 } else if (linearVarType->isFloatingPointTy()) {
209 // Float path: perform multiply in integer, then convert to float
210 step = builder.CreateSExtOrTrunc(step, iv->getType());
211 llvm::Value *mulInst = builder.CreateMul(iv, step);
212
213 llvm::LoadInst *linearVarStart =
214 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
215 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
216 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
217 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
218 } else {
219 llvm_unreachable(
220 "Linear variable must be of integer or floating-point type");
221 }
222 }
223 }
224
225 // Linear variable finalization is conditional on the last logical iteration.
226 // Create BB splits to manage the same.
227 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
228 llvm::BasicBlock *loopExit) {
229 linearFinalizationBB = loopExit->splitBasicBlock(
230 loopExit->getTerminator(), "omp_loop.linear_finalization");
231 linearExitBB = linearFinalizationBB->splitBasicBlock(
232 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
233 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
234 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
235 }
236
237 // Finalize the linear vars
238 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
239 finalizeLinearVar(llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation,
241 llvm::Value *lastIter) {
242 // Emit condition to check whether last logical iteration is being executed
243 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
244 llvm::Value *loopLastIterLoad = builder.CreateLoad(
245 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
246 llvm::Value *isLast =
247 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
248 llvm::ConstantInt::get(
249 llvm::Type::getInt32Ty(builder.getContext()), 0));
250 // Store the linear variable values to original variables.
251 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
252 for (size_t index = 0; index < linearOrigVal.size(); index++) {
253 llvm::LoadInst *linearVarTemp =
254 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
255 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
256 }
257
258 // Create conditional branch such that the linear variable
259 // values are stored to original variables only at the
260 // last logical iteration
261 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
262 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
263 linearFinalizationBB->getTerminator()->eraseFromParent();
264 // Emit barrier
265 builder.SetInsertPoint(linearExitBB->getTerminator());
266 return moduleTranslation.getOpenMPBuilder()->createBarrier(
267 builder.saveIP(), llvm::omp::OMPD_barrier);
268 }
269
270 // Emit stores for linear variables. Useful in case of SIMD
271 // construct.
272 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
273 for (size_t index = 0; index < linearOrigVal.size(); index++) {
274 llvm::LoadInst *linearVarTemp =
275 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
276 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
277 }
278 }
279
280 // Rewrite all uses of the original variable in `BBName`
281 // with the linear variable in-place
282 void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
283 size_t varIndex) {
284 llvm::SmallVector<llvm::User *> users;
285 for (llvm::User *user : linearOrigVal[varIndex]->users())
286 users.push_back(user);
287 for (auto *user : users) {
288 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
289 if (userInst->getParent()->getName().str().find(BBName) !=
290 std::string::npos)
291 user->replaceUsesOfWith(linearOrigVal[varIndex],
292 linearLoopBodyTemps[varIndex]);
293 }
294 }
295 }
296};
297
298} // namespace
299
300/// Looks up from the operation from and returns the PrivateClauseOp with
301/// name symbolName
302static omp::PrivateClauseOp findPrivatizer(Operation *from,
303 SymbolRefAttr symbolName) {
304 omp::PrivateClauseOp privatizer =
306 symbolName);
307 assert(privatizer && "privatizer not found in the symbol table");
308 return privatizer;
309}
310
311/// Check whether translation to LLVM IR for the given operation is currently
312/// supported. If not, descriptive diagnostics will be emitted to let users know
313/// this is a not-yet-implemented feature.
314///
315/// \returns success if no unimplemented features are needed to translate the
316/// given operation.
317static LogicalResult checkImplementationStatus(Operation &op) {
318 auto todo = [&op](StringRef clauseName) {
319 return op.emitError() << "not yet implemented: Unhandled clause "
320 << clauseName << " in " << op.getName()
321 << " operation";
322 };
323
324 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
325 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
326 result = todo("allocate");
327 };
328 auto checkBare = [&todo](auto op, LogicalResult &result) {
329 if (op.getBare())
330 result = todo("ompx_bare");
331 };
332 auto checkCollapse = [&todo](auto op, LogicalResult &result) {
333 if (op.getCollapseNumLoops() > 1)
334 result = todo("collapse");
335 };
336 auto checkDepend = [&todo](auto op, LogicalResult &result) {
337 if (!op.getDependVars().empty() || op.getDependKinds())
338 result = todo("depend");
339 };
340 auto checkHint = [](auto op, LogicalResult &) {
341 if (op.getHint())
342 op.emitWarning("hint clause discarded");
343 };
344 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
345 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
346 op.getInReductionSyms())
347 result = todo("in_reduction");
348 };
349 auto checkNowait = [&todo](auto op, LogicalResult &result) {
350 if (op.getNowait())
351 result = todo("nowait");
352 };
353 auto checkOrder = [&todo](auto op, LogicalResult &result) {
354 if (op.getOrder() || op.getOrderMod())
355 result = todo("order");
356 };
357 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
358 if (op.getParLevelSimd())
359 result = todo("parallelization-level");
360 };
361 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
362 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
363 result = todo("privatization");
364 };
365 auto checkReduction = [&todo](auto op, LogicalResult &result) {
366 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopOp>(op))
367 if (!op.getReductionVars().empty() || op.getReductionByref() ||
368 op.getReductionSyms())
369 result = todo("reduction");
370 if (op.getReductionMod() &&
371 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
372 result = todo("reduction with modifier");
373 };
374 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
375 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
376 op.getTaskReductionSyms())
377 result = todo("task_reduction");
378 };
379 auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
380 if (op.hasNumTeamsMultiDim())
381 result = todo("num_teams with multi-dimensional values");
382 };
383 auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
384 if (op.hasNumThreadsMultiDim())
385 result = todo("num_threads with multi-dimensional values");
386 };
387
388 auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
389 if (op.hasThreadLimitMultiDim())
390 result = todo("thread_limit with multi-dimensional values");
391 };
392
393 LogicalResult result = success();
395 .Case([&](omp::DistributeOp op) {
396 checkAllocate(op, result);
397 checkOrder(op, result);
398 })
399 .Case([&](omp::LoopNestOp op) {
400 if (mlir::isa<omp::TaskloopOp>(op.getOperation()->getParentOp()))
401 checkCollapse(op, result);
402 })
403 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
404 .Case([&](omp::SectionsOp op) {
405 checkAllocate(op, result);
406 checkPrivate(op, result);
407 checkReduction(op, result);
408 })
409 .Case([&](omp::SingleOp op) {
410 checkAllocate(op, result);
411 checkPrivate(op, result);
412 })
413 .Case([&](omp::TeamsOp op) {
414 checkAllocate(op, result);
415 checkPrivate(op, result);
416 checkNumTeams(op, result);
417 checkThreadLimit(op, result);
418 })
419 .Case([&](omp::TaskOp op) {
420 checkAllocate(op, result);
421 checkInReduction(op, result);
422 })
423 .Case([&](omp::TaskgroupOp op) {
424 checkAllocate(op, result);
425 checkTaskReduction(op, result);
426 })
427 .Case([&](omp::TaskwaitOp op) {
428 checkDepend(op, result);
429 checkNowait(op, result);
430 })
431 .Case([&](omp::TaskloopOp op) {
432 checkAllocate(op, result);
433 checkInReduction(op, result);
434 checkReduction(op, result);
435 })
436 .Case([&](omp::WsloopOp op) {
437 checkAllocate(op, result);
438 checkOrder(op, result);
439 checkReduction(op, result);
440 })
441 .Case([&](omp::ParallelOp op) {
442 checkAllocate(op, result);
443 checkReduction(op, result);
444 checkNumThreads(op, result);
445 })
446 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
447 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
449 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
450 [&](auto op) { checkDepend(op, result); })
451 .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
452 .Case([&](omp::TargetOp op) {
453 checkAllocate(op, result);
454 checkBare(op, result);
455 checkInReduction(op, result);
456 checkThreadLimit(op, result);
457 })
458 .Default([](Operation &) {
459 // Assume all clauses for an operation can be translated unless they are
460 // checked above.
461 });
462 return result;
463}
464
465static LogicalResult handleError(llvm::Error error, Operation &op) {
466 LogicalResult result = success();
467 if (error) {
468 llvm::handleAllErrors(
469 std::move(error),
470 [&](const PreviouslyReportedError &) { result = failure(); },
471 [&](const llvm::ErrorInfoBase &err) {
472 result = op.emitError(err.message());
473 });
474 }
475 return result;
476}
477
478template <typename T>
479static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
480 if (!result)
481 return handleError(result.takeError(), op);
482
483 return success();
484}
485
486/// Find the insertion point for allocas given the current insertion point for
487/// normal operations in the builder.
488static llvm::OpenMPIRBuilder::InsertPointTy
489findAllocaInsertPoint(llvm::IRBuilderBase &builder,
490 LLVM::ModuleTranslation &moduleTranslation) {
491 // If there is an alloca insertion point on stack, i.e. we are in a nested
492 // operation and a specific point was provided by some surrounding operation,
493 // use it.
494 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
495 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
496 [&](OpenMPAllocaStackFrame &frame) {
497 allocaInsertPoint = frame.allocaInsertPoint;
498 return WalkResult::interrupt();
499 });
500 // In cases with multiple levels of outlining, the tree walk might find an
501 // alloca insertion point that is inside the original function while the
502 // builder insertion point is inside the outlined function. We need to make
503 // sure that we do not use it in those cases.
504 if (walkResult.wasInterrupted() &&
505 allocaInsertPoint.getBlock()->getParent() ==
506 builder.GetInsertBlock()->getParent())
507 return allocaInsertPoint;
508
509 // Otherwise, insert to the entry block of the surrounding function.
510 // If the current IRBuilder InsertPoint is the function's entry, it cannot
511 // also be used for alloca insertion which would result in insertion order
512 // confusion. Create a new BasicBlock for the Builder and use the entry block
513 // for the allocs.
514 // TODO: Create a dedicated alloca BasicBlock at function creation such that
515 // we do not need to move the current InertPoint here.
516 if (builder.GetInsertBlock() ==
517 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
518 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
519 "Assuming end of basic block");
520 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
521 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
522 builder.GetInsertBlock()->getNextNode());
523 builder.CreateBr(entryBB);
524 builder.SetInsertPoint(entryBB);
525 }
526
527 llvm::BasicBlock &funcEntryBlock =
528 builder.GetInsertBlock()->getParent()->getEntryBlock();
529 return llvm::OpenMPIRBuilder::InsertPointTy(
530 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
531}
532
533/// Find the loop information structure for the loop nest being translated. It
534/// will return a `null` value unless called from the translation function for
535/// a loop wrapper operation after successfully translating its body.
536static llvm::CanonicalLoopInfo *
538 llvm::CanonicalLoopInfo *loopInfo = nullptr;
539 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
540 [&](OpenMPLoopInfoStackFrame &frame) {
541 loopInfo = frame.loopInfo;
542 return WalkResult::interrupt();
543 });
544 return loopInfo;
545}
546
547/// Converts the given region that appears within an OpenMP dialect operation to
548/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
549/// region, and a branch from any block with an successor-less OpenMP terminator
550/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
551/// of the continuation block if provided.
553 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
554 LLVM::ModuleTranslation &moduleTranslation,
555 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
556 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
557
558 llvm::BasicBlock *continuationBlock =
559 splitBB(builder, true, "omp.region.cont");
560 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
561
562 llvm::LLVMContext &llvmContext = builder.getContext();
563 for (Block &bb : region) {
564 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
565 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
566 builder.GetInsertBlock()->getNextNode());
567 moduleTranslation.mapBlock(&bb, llvmBB);
568 }
569
570 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
571
572 // Terminators (namely YieldOp) may be forwarding values to the region that
573 // need to be available in the continuation block. Collect the types of these
574 // operands in preparation of creating PHI nodes. This is skipped for loop
575 // wrapper operations, for which we know in advance they have no terminators.
576 SmallVector<llvm::Type *> continuationBlockPHITypes;
577 unsigned numYields = 0;
578
579 if (!isLoopWrapper) {
580 bool operandsProcessed = false;
581 for (Block &bb : region.getBlocks()) {
582 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
583 if (!operandsProcessed) {
584 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
585 continuationBlockPHITypes.push_back(
586 moduleTranslation.convertType(yield->getOperand(i).getType()));
587 }
588 operandsProcessed = true;
589 } else {
590 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
591 "mismatching number of values yielded from the region");
592 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
593 llvm::Type *operandType =
594 moduleTranslation.convertType(yield->getOperand(i).getType());
595 (void)operandType;
596 assert(continuationBlockPHITypes[i] == operandType &&
597 "values of mismatching types yielded from the region");
598 }
599 }
600 numYields++;
601 }
602 }
603 }
604
605 // Insert PHI nodes in the continuation block for any values forwarded by the
606 // terminators in this region.
607 if (!continuationBlockPHITypes.empty())
608 assert(
609 continuationBlockPHIs &&
610 "expected continuation block PHIs if converted regions yield values");
611 if (continuationBlockPHIs) {
612 llvm::IRBuilderBase::InsertPointGuard guard(builder);
613 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
614 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
615 for (llvm::Type *ty : continuationBlockPHITypes)
616 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
617 }
618
619 // Convert blocks one by one in topological order to ensure
620 // defs are converted before uses.
622 for (Block *bb : blocks) {
623 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
624 // Retarget the branch of the entry block to the entry block of the
625 // converted region (regions are single-entry).
626 if (bb->isEntryBlock()) {
627 assert(sourceTerminator->getNumSuccessors() == 1 &&
628 "provided entry block has multiple successors");
629 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
630 "ContinuationBlock is not the successor of the entry block");
631 sourceTerminator->setSuccessor(0, llvmBB);
632 }
633
634 llvm::IRBuilderBase::InsertPointGuard guard(builder);
635 if (failed(
636 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
637 return llvm::make_error<PreviouslyReportedError>();
638
639 // Create a direct branch here for loop wrappers to prevent their lack of a
640 // terminator from causing a crash below.
641 if (isLoopWrapper) {
642 builder.CreateBr(continuationBlock);
643 continue;
644 }
645
646 // Special handling for `omp.yield` and `omp.terminator` (we may have more
647 // than one): they return the control to the parent OpenMP dialect operation
648 // so replace them with the branch to the continuation block. We handle this
649 // here to avoid relying inter-function communication through the
650 // ModuleTranslation class to set up the correct insertion point. This is
651 // also consistent with MLIR's idiom of handling special region terminators
652 // in the same code that handles the region-owning operation.
653 Operation *terminator = bb->getTerminator();
654 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
655 builder.CreateBr(continuationBlock);
656
657 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
658 (*continuationBlockPHIs)[i]->addIncoming(
659 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
660 }
661 }
662 // After all blocks have been traversed and values mapped, connect the PHI
663 // nodes to the results of preceding blocks.
664 LLVM::detail::connectPHINodes(region, moduleTranslation);
665
666 // Remove the blocks and values defined in this region from the mapping since
667 // they are not visible outside of this region. This allows the same region to
668 // be converted several times, that is cloned, without clashes, and slightly
669 // speeds up the lookups.
670 moduleTranslation.forgetMapping(region);
671
672 return continuationBlock;
673}
674
675/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
676static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
677 switch (kind) {
678 case omp::ClauseProcBindKind::Close:
679 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
680 case omp::ClauseProcBindKind::Master:
681 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
682 case omp::ClauseProcBindKind::Primary:
683 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
684 case omp::ClauseProcBindKind::Spread:
685 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
686 }
687 llvm_unreachable("Unknown ClauseProcBindKind kind");
688}
689
690/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
691static LogicalResult
692convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
693 LLVM::ModuleTranslation &moduleTranslation) {
694 auto maskedOp = cast<omp::MaskedOp>(opInst);
695 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
696
697 if (failed(checkImplementationStatus(opInst)))
698 return failure();
699
700 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
701 // MaskedOp has only one region associated with it.
702 auto &region = maskedOp.getRegion();
703 builder.restoreIP(codeGenIP);
704 return convertOmpOpRegions(region, "omp.masked.region", builder,
705 moduleTranslation)
706 .takeError();
707 };
708
709 // TODO: Perform finalization actions for variables. This has to be
710 // called for variables which have destructors/finalizers.
711 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
712
713 llvm::Value *filterVal = nullptr;
714 if (auto filterVar = maskedOp.getFilteredThreadId()) {
715 filterVal = moduleTranslation.lookupValue(filterVar);
716 } else {
717 llvm::LLVMContext &llvmContext = builder.getContext();
718 filterVal =
719 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
720 }
721 assert(filterVal != nullptr);
722 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
723 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
724 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
725 finiCB, filterVal);
726
727 if (failed(handleError(afterIP, opInst)))
728 return failure();
729
730 builder.restoreIP(*afterIP);
731 return success();
732}
733
734/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
735static LogicalResult
736convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
737 LLVM::ModuleTranslation &moduleTranslation) {
738 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
739 auto masterOp = cast<omp::MasterOp>(opInst);
740
741 if (failed(checkImplementationStatus(opInst)))
742 return failure();
743
744 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
745 // MasterOp has only one region associated with it.
746 auto &region = masterOp.getRegion();
747 builder.restoreIP(codeGenIP);
748 return convertOmpOpRegions(region, "omp.master.region", builder,
749 moduleTranslation)
750 .takeError();
751 };
752
753 // TODO: Perform finalization actions for variables. This has to be
754 // called for variables which have destructors/finalizers.
755 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
756
757 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
758 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
759 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
760 finiCB);
761
762 if (failed(handleError(afterIP, opInst)))
763 return failure();
764
765 builder.restoreIP(*afterIP);
766 return success();
767}
768
/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto criticalOp = cast<omp::CriticalOp>(opInst);

  // Bail out early if the op carries clauses not yet supported by this
  // translation.
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Emits the critical section's body at the insertion point chosen by the
  // OpenMPIRBuilder.
  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
    // CriticalOp has only one region associated with it.
    auto &region = cast<omp::CriticalOp>(opInst).getRegion();
    builder.restoreIP(codeGenIP);
    return convertOmpOpRegions(region, "omp.critical.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
  llvm::Constant *hint = nullptr;

  // If it has a name, it probably has a hint too.
  if (criticalOp.getNameAttr()) {
    // The verifiers in OpenMP Dialect guarantee that all the pointers are
    // non-null.
    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
    // NOTE(review): the symbol lookup resolving `symbolRef` to the
    // omp.critical.declare op appears truncated here — confirm against the
    // upstream source.
    auto criticalDeclareOp =
        symbolRef);
    // Materialize the declared hint as an i32 constant for the runtime call.
    hint =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
                               static_cast<int>(criticalDeclareOp.getHint()));
  }
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createCritical(
          ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  // Continue emitting code after the critical construct.
  builder.restoreIP(*afterIP);
  return success();
}
818
/// A util to collect info needed to convert delayed privatizers from MLIR to
/// LLVM.
  template <typename OP>
      : blockArgs(
            cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
    // One MLIR value and one LLVM value will be recorded per private block
    // argument.
    mlirVars.reserve(blockArgs.size());
    llvmVars.reserve(blockArgs.size());
    collectPrivatizationDecls<OP>(op);

    for (mlir::Value privateVar : op.getPrivateVars())
      mlirVars.push_back(privateVar);
  }

private:
  /// Populates `privatizations` with privatization declarations used for the
  /// given op.
  template <class OP>
  void collectPrivatizationDecls(OP op) {
    // No `private` clause symbols on the op means nothing to collect.
    std::optional<ArrayAttr> attr = op.getPrivateSyms();
    if (!attr)
      return;

    privatizers.reserve(privatizers.size() + attr->size());
    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
      // Resolve each symbol to its privatizer declaration op.
      privatizers.push_back(findPrivatizer(op, symbolRef));
    }
  }
};
854
/// Populates `reductions` with reduction declarations used in the given op.
template <typename T>
static void
  // Ops without reduction clauses carry no reduction symbols.
  std::optional<ArrayAttr> attr = op.getReductionSyms();
  if (!attr)
    return;

  reductions.reserve(reductions.size() + op.getNumReductionVars());
  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
    // Resolve each reduction symbol to its omp.declare_reduction op.
    reductions.push_back(
        op, symbolRef));
  }
}
871
/// Translates the blocks contained in the given region and appends them to at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  // An empty region translates to nothing.
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (region.hasOneBlock()) {
    // Temporarily detach the insertion block's terminator (if any) so the
    // region's ops can be appended before it.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

    if (failed(moduleTranslation.convertBlock(
            region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          *continuationBlockArgs,
          moduleTranslation.lookupValues(region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    // Re-attach the previously detached terminator at the end of the block.
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(block, block->begin());
      } else {
        potentialTerminator->insertAfter(&block->back());
      }
    }

    return success();
  }

  // Multi-block case: delegate to the full region converter, which returns a
  // continuation block and fills `phis` with the yielded values.
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);

  if (failed(handleError(continuationBlock, *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(*continuationBlockArgs, phis);
  builder.SetInsertPoint(*continuationBlock,
                         (*continuationBlock)->getFirstInsertionPt());
  return success();
}
939
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
/// Non-atomic combiner: (insert point, lhs, rhs, out result) -> insert point.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
/// Atomic combiner: (insert point, element type, lhs addr, rhs addr) ->
/// insert point.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
/// By-ref data pointer accessor: (insert point, by-ref value, out data ptr)
/// -> insert point.
using OwningDataPtrPtrReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
955
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
/// reduction declaration. The generator uses `builder` but ignores its
/// insertion point.
static OwningReductionGen
makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the combiner region's block arguments to the incoming operands
    // before inlining the region.
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    // The combiner region yields exactly one value: the combined result.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };
  return gen;
}
984
/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
/// given reduction declaration. The generator uses `builder` but ignores its
/// insertion point. Returns null if there is no atomic region available in the
/// reduction declaration.
static OwningAtomicReductionGen
makeAtomicReductionGen(omp::DeclareReductionOp decl,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  // A null generator signals to the OpenMPIRBuilder that no atomic variant
  // exists for this reduction.
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningAtomicReductionGen atomicGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the atomic region's block arguments to the incoming addresses
    // before inlining the region.
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    // The atomic region updates memory in place and is expected to yield no
    // values.
    assert(phis.empty());
    return builder.saveIP();
  };
  return atomicGen;
}
1017
/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
/// the given reduction declaration. The generator uses `builder` but ignores
/// its insertion point. Returns null if there is no `data_ptr_ptr` region
/// available in the reduction declaration.
static OwningDataPtrPtrReductionGen
makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
  // Only by-ref reductions need an accessor for the underlying data pointer.
  if (!isByRef)
    return OwningDataPtrPtrReductionGen();

  // Capture decl by value: the lambda may outlive this function.
  OwningDataPtrPtrReductionGen refDataPtrGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *byRefVal, llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the region's block argument to the by-ref value before inlining.
    moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
    builder.restoreIP(insertPoint);
    if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
                                       "omp.data_ptr_ptr.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
    // The region yields exactly one value: the pointer to the reduction data.
    result = llvm::getSingleElement(phis);
    return builder.saveIP();
  };

  return refDataPtrGen;
}
1046
1047/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1048static LogicalResult
1049convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1050 LLVM::ModuleTranslation &moduleTranslation) {
1051 auto orderedOp = cast<omp::OrderedOp>(opInst);
1052
1053 if (failed(checkImplementationStatus(opInst)))
1054 return failure();
1055
1056 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1057 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1058 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1059 SmallVector<llvm::Value *> vecValues =
1060 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1061
1062 size_t indexVecValues = 0;
1063 while (indexVecValues < vecValues.size()) {
1064 SmallVector<llvm::Value *> storeValues;
1065 storeValues.reserve(numLoops);
1066 for (unsigned i = 0; i < numLoops; i++) {
1067 storeValues.push_back(vecValues[indexVecValues]);
1068 indexVecValues++;
1069 }
1070 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1071 findAllocaInsertPoint(builder, moduleTranslation);
1072 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1073 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1074 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1075 }
1076 return success();
1077}
1078
1079/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1080/// OpenMPIRBuilder.
1081static LogicalResult
1082convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1083 LLVM::ModuleTranslation &moduleTranslation) {
1084 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1085 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1086
1087 if (failed(checkImplementationStatus(opInst)))
1088 return failure();
1089
1090 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1091 // OrderedOp has only one region associated with it.
1092 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1093 builder.restoreIP(codeGenIP);
1094 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1095 moduleTranslation)
1096 .takeError();
1097 };
1098
1099 // TODO: Perform finalization actions for variables. This has to be
1100 // called for variables which have destructors/finalizers.
1101 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1102
1103 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1104 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1105 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1106 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1107
1108 if (failed(handleError(afterIP, opInst)))
1109 return failure();
1110
1111 builder.restoreIP(*afterIP);
1112 return success();
1113}
1114
namespace {
/// Contains the arguments for an LLVM store operation
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  // The value to be stored.
  llvm::Value *value;
  // The pointer the value will be stored through.
  llvm::Value *address;
};
} // namespace
1125
/// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas
template <typename T>
static LogicalResult
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // Restore the builder's insertion point on scope exit; all allocas are
  // emitted just before the alloca block's terminator.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // delay creating stores until after all allocas
  deferredStores.reserve(loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      // By-ref reduction without an alloc region: handled elsewhere (see
      // initReductionVars).
      if (allocRegion.empty())
        continue;

      if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      // Inlining the region may have moved the insertion point; go back to
      // the alloca block for the pointer alloca below.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      // Normalize both pointers to the generic pointer type expected by the
      // reduction runtime.
      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);

      // The store of the real variable's address is deferred until after all
      // allocas have been emitted.
      deferredStores.emplace_back(castPhi, castVar);

      privateReductionVariables[i] = castVar;
      moduleTranslation.mapValue(reductionArgs[i], castPhi);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");
      // By-val: allocate the reduction variable directly.
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);

      moduleTranslation.mapValue(reductionArgs[i], castVar);
      privateReductionVariables[i] = castVar;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
    }
  }

  return success();
}
1195
/// Map input arguments to reduction initialization region
template <typename T>
static void
                   llvm::IRBuilderBase &builder,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   unsigned i) {
  // map input argument to the initialization region
  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
  Region &initializerRegion = reduction.getInitializerRegion();
  Block &entry = initializerRegion.front();

  mlir::Value mlirSource = loop.getReductionVars()[i];
  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
  llvm::Value *origVal = llvmSource;
  // If a non-pointer value is expected, load the value from the source pointer.
  if (!isa<LLVM::LLVMPointerType>(
          reduction.getInitializerMoldArg().getType()) &&
      isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
    origVal =
        builder.CreateLoad(moduleTranslation.convertType(
                               reduction.getInitializerMoldArg().getType()),
                           llvmSource, "omp_orig");
  }
  // The mold argument carries the original ("omp_orig") value.
  moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);

  // A second entry-block argument, when present, receives the allocation
  // produced by the `alloc` region for this reduction variable.
  if (entry.getNumArguments() > 1) {
    llvm::Value *allocation =
        reductionVariableMap.lookup(loop.getReductionVars()[i]);
    moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
  }
}
1229
1230static void
1231setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1232 llvm::BasicBlock *block = nullptr) {
1233 if (block == nullptr)
1234 block = builder.GetInsertBlock();
1235
1236 if (block->empty() || block->getTerminator() == nullptr)
1237 builder.SetInsertPoint(block);
1238 else
1239 builder.SetInsertPoint(block->getTerminator());
1240}
1241
/// Inline reductions' `init` regions. This functions assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder's insertions point in a state
/// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  // Nothing to initialize when the op has no reduction clauses.
  if (op.getNumReductionVars() == 0)
    return success();

  // Split off a dedicated block for the init code, then move to the alloca
  // block to emit any remaining by-ref allocations first.
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      // By-ref reductions with an alloc region were already allocated in
      // allocReductionVars.
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduciton variable allocated in the inlined region)
      byRefVars[i] = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
    }
  }

  setInsertPointForPossiblyEmptyBlock(builder, initBlock);

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
                          reductionVariableMap, i);

    // TODO In some cases (specially on the GPU), the init regions may
    // contains stack alloctaions. If the region is inlined in a loop, this is
    // problematic. Instead of just inlining the region, handle allocations by
    // hoisting fixed length allocations to the function entry and using
    // stacksave and restore for variable length ones.
    if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(phis[0], privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1341
/// Collect reduction info
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    ArrayRef<bool> isByRef) {
  unsigned numReductions = loop.getNumReductionVars();

  // Build one owning generator of each kind per reduction; the ReductionInfo
  // records below hold references into these vectors, so they must outlive
  // the createReductions call.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
        reductionDecls[i], builder, moduleTranslation, isByRef[i]));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    // Only pass an atomic generator when the declaration provides one.
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    // Recover the element type allocated by the `alloc` region, if any, by
    // looking for an alloca inside it.
    mlir::Type allocatedType;
    reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
      if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
        allocatedType = alloca.getElemType();
      }

    });

    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         /*ReductionGenClang=*/nullptr, atomicGen,
         allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
         reductionDecls[i].getByrefElementType()
             ? moduleTranslation.convertType(
                   *reductionDecls[i].getByrefElementType())
             : nullptr});
  }
}
1397
/// handling of DeclareReductionOp's cleanup region
static LogicalResult
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
    // Declarations without a cleanup region need no finalization.
    if (cleanupRegion->empty())
      continue;

    // map the argument to the cleanup region
    Block &entry = cleanupRegion->front();

    // If the current block already has a terminator, emit the cleanup code
    // before it rather than after.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    // Depending on the caller, the cleanup region expects either the loaded
    // value or the raw pointer to the private variable.
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  moduleTranslation.convertType(entry.getArgument(0).getType()),
                  privateVariables[i])
            : privateVariables[i];

    moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);

    if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
                                       moduleTranslation)))
      return failure();

    // clear block argument mapping in case it needs to be re-created with a
    // different source for another use of the same reduction decl
    moduleTranslation.forgetMapping(*cleanupRegion);
  }
  return success();
}
1436
// TODO: not used by ParallelOp
template <class OP>
static LogicalResult createReductionsAndCleanup(
    OP op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
    bool isNowait = false, bool isTeamsReduction = false) {
  // Process the reductions if required.
  if (op.getNumReductionVars() == 0)
    return success();

  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
  SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Create the reduction generators. We need to own them here because
  // ReductionInfo only accepts references to the generators.
  collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                       owningReductionGens, owningAtomicReductionGens,
                       owningReductionGenRefDataPtrGens,
                       privateReductionVariables, reductionInfos, isByRef);

  // The call to createReductions below expects the block to have a
  // terminator. Create an unreachable instruction to serve as terminator
  // and remove it later.
  llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
  builder.SetInsertPoint(tempTerminator);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
                                   isByRef, isNowait, isTeamsReduction);

  if (failed(handleError(contInsertPoint, *op)))
    return failure();

  if (!contInsertPoint->getBlock())
    return op->emitOpError() << "failed to convert reductions";

  // Reductions are followed by an implicit barrier for the `for` construct.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);

  if (failed(handleError(afterIP, *op)))
    return failure();

  // Drop the temporary terminator now that the reductions are emitted.
  tempTerminator->eraseFromParent();
  builder.restoreIP(*afterIP);

  // after the construct, deallocate private reduction variables
  SmallVector<Region *> reductionRegions;
  llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
                  [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
                                moduleTranslation, builder,
                                "omp.reduction.cleanup");
  // NOTE(review): this statement is unreachable — the function returns on the
  // line above; consider removing it.
  return success();
}
1499
1500static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1501 if (!attr)
1502 return {};
1503 return *attr;
1504}
1505
// TODO: not used by omp.parallel
template <typename OP>
    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
    DenseMap<Value, llvm::Value *> &reductionVariableMap,
    llvm::ArrayRef<bool> isByRef) {
  // Nothing to do when the op carries no reduction clauses.
  if (op.getNumReductionVars() == 0)
    return success();

  SmallVector<DeferredStore> deferredStores;

  // First emit all allocations (stores are deferred so they land after the
  // allocas), then inline the `init` regions.
  if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
                                allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRef)))
    return failure();

  return initReductionVars(op, reductionArgs, builder, moduleTranslation,
                           allocaIP.getBlock(), reductionDecls,
                           privateReductionVariables, reductionVariableMap,
                           isByRef, deferredStores);
}
1532
1533/// Return the llvm::Value * corresponding to the `privateVar` that
1534/// is being privatized. It isn't always as simple as looking up
1535/// moduleTranslation with privateVar. For instance, in case of
1536/// an allocatable, the descriptor for the allocatable is privatized.
1537/// This descriptor is mapped using an MapInfoOp. So, this function
1538/// will return a pointer to the llvm::Value corresponding to the
1539/// block argument for the mapped descriptor.
1540static llvm::Value *
1541findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1542 LLVM::ModuleTranslation &moduleTranslation,
1543 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1544 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1545 return moduleTranslation.lookupValue(privateVar);
1546
1547 Value blockArg = (*mappedPrivateVars)[privateVar];
1548 Type privVarType = privateVar.getType();
1549 Type blockArgType = blockArg.getType();
1550 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1551 "A block argument corresponding to a mapped var should have "
1552 "!llvm.ptr type");
1553
1554 if (privVarType == blockArgType)
1555 return moduleTranslation.lookupValue(blockArg);
1556
1557 // This typically happens when the privatized type is lowered from
1558 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1559 // struct/pair is passed by value. But, mapped values are passed only as
1560 // pointers, so before we privatize, we must load the pointer.
1561 if (!isa<LLVM::LLVMPointerType>(privVarType))
1562 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1563 moduleTranslation.lookupValue(blockArg));
1564
1565 return moduleTranslation.lookupValue(privateVar);
1566}
1567
/// Initialize a single (first)private variable. You probably want to use
/// allocateAndInitPrivateVars instead of this.
/// This returns the private variable which has been initialized. This
/// variable should be mapped before constructing the body of the Op.
initPrivateVar(llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation,
               omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
               BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
               llvm::BasicBlock *privInitBlock,
               llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Without an `init` region the allocated private variable is used as-is.
  Region &initRegion = privDecl.getInitRegion();
  if (initRegion.empty())
    return llvmPrivateVar;

  // The init region reads the original (mold) value; it must exist.
  assert(nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);

  // in-place convert the private initialization region
  if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
                                     moduleTranslation, &phis)))
    return llvm::createStringError(
        "failed to inline `init` region of `omp.private`");

  assert(phis.size() == 1 && "expected one allocation to be yielded");

  // clear init region block argument mapping in case it needs to be
  // re-created with a different source for another use of the same
  // reduction decl
  moduleTranslation.forgetMapping(initRegion);

  // Prefer the value yielded from the init region to the allocated private
  // variable in case the region is operating on arguments by-value (e.g.
  // Fortran character boxes).
  return phis[0];
}
1606
/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
    omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
    llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Resolve the original (non-private) value, honoring any mapped-variable
  // indirection, then delegate to the main overload.
  return initPrivateVar(
      builder, moduleTranslation, privDecl,
      findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
                          mappedPrivateVars),
      blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
}
1619
1620static llvm::Error
1621initPrivateVars(llvm::IRBuilderBase &builder,
1622 LLVM::ModuleTranslation &moduleTranslation,
1623 PrivateVarsInfo &privateVarsInfo,
1624 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1625 if (privateVarsInfo.blockArgs.empty())
1626 return llvm::Error::success();
1627
1628 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1629 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1630
1631 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1632 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1633 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1634 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1636 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1637 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1638
1639 if (!privVarOrErr)
1640 return privVarOrErr.takeError();
1641
1642 llvmPrivateVar = privVarOrErr.get();
1643 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1644
1646 }
1647
1648 return llvm::Error::success();
1649}
1650
1651/// Allocate and initialize delayed private variables. Returns the basic block
1652/// which comes after all of these allocations. llvm::Value * for each of these
1653/// private variables are populated in llvmPrivateVars.
1655allocatePrivateVars(llvm::IRBuilderBase &builder,
1656 LLVM::ModuleTranslation &moduleTranslation,
1657 PrivateVarsInfo &privateVarsInfo,
1658 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1659 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1660 // Allocate private vars
1661 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1662 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1663 allocaTerminator->getIterator()),
1664 true, allocaTerminator->getStableDebugLoc(),
1665 "omp.region.after_alloca");
1666
1667 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1668 // Update the allocaTerminator since the alloca block was split above.
1669 allocaTerminator = allocaIP.getBlock()->getTerminator();
1670 builder.SetInsertPoint(allocaTerminator);
1671 // The new terminator is an uncondition branch created by the splitBB above.
1672 assert(allocaTerminator->getNumSuccessors() == 1 &&
1673 "This is an unconditional branch created by splitBB");
1674
1675 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1676 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1677
1678 unsigned int allocaAS =
1679 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1680 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1681 ->getDataLayout()
1682 .getProgramAddressSpace();
1683
1684 for (auto [privDecl, mlirPrivVar, blockArg] :
1685 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1686 privateVarsInfo.blockArgs)) {
1687 llvm::Type *llvmAllocType =
1688 moduleTranslation.convertType(privDecl.getType());
1689 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1690 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1691 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1692 if (allocaAS != defaultAS)
1693 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1694 builder.getPtrTy(defaultAS));
1695
1696 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1697 }
1698
1699 return afterAllocas;
1700}
1701
1702/// This can't always be determined statically, but when we can, it is good to
1703/// avoid generating compiler-added barriers which will deadlock the program.
1705 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1706 parent = parent->getParentOp()) {
1707 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1708 return true;
1709
1710 // e.g.
1711 // omp.single {
1712 // omp.parallel {
1713 // op
1714 // }
1715 // }
1716 if (mlir::isa<omp::ParallelOp>(parent))
1717 return false;
1718 }
1719 return false;
1720}
1721
1722static LogicalResult copyFirstPrivateVars(
1723 mlir::Operation *op, llvm::IRBuilderBase &builder,
1724 LLVM::ModuleTranslation &moduleTranslation,
1726 ArrayRef<llvm::Value *> llvmPrivateVars,
1727 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1728 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1729 // Apply copy region for firstprivate.
1730 bool needsFirstprivate =
1731 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1732 return privOp.getDataSharingType() ==
1733 omp::DataSharingClauseType::FirstPrivate;
1734 });
1735
1736 if (!needsFirstprivate)
1737 return success();
1738
1739 llvm::BasicBlock *copyBlock =
1740 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1741 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1742
1743 for (auto [decl, moldVar, llvmVar] :
1744 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1745 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1746 continue;
1747
1748 // copyRegion implements `lhs = rhs`
1749 Region &copyRegion = decl.getCopyRegion();
1750
1751 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1752
1753 // map copyRegion lhs arg
1754 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1755
1756 // in-place convert copy region
1757 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1758 moduleTranslation)))
1759 return decl.emitError("failed to inline `copy` region of `omp.private`");
1760
1762
1763 // ignore unused value yielded from copy region
1764
1765 // clear copy region block argument mapping in case it needs to be
1766 // re-created with different sources for reuse of the same reduction
1767 // decl
1768 moduleTranslation.forgetMapping(copyRegion);
1769 }
1770
1771 if (insertBarrier && !opIsInSingleThread(op)) {
1772 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1773 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1774 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1775 if (failed(handleError(res, *op)))
1776 return failure();
1777 }
1778
1779 return success();
1780}
1781
1782static LogicalResult copyFirstPrivateVars(
1783 mlir::Operation *op, llvm::IRBuilderBase &builder,
1784 LLVM::ModuleTranslation &moduleTranslation,
1785 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1786 ArrayRef<llvm::Value *> llvmPrivateVars,
1787 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1788 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1789 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1790 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1791 // map copyRegion rhs arg
1792 llvm::Value *moldVar = findAssociatedValue(
1793 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1794 assert(moldVar);
1795 return moldVar;
1796 });
1797 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1798 llvmPrivateVars, privateDecls, insertBarrier,
1799 mappedPrivateVars);
1800}
1801
1802static LogicalResult
1803cleanupPrivateVars(llvm::IRBuilderBase &builder,
1804 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1805 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1807 // private variable deallocation
1808 SmallVector<Region *> privateCleanupRegions;
1809 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1810 [](omp::PrivateClauseOp privatizer) {
1811 return &privatizer.getDeallocRegion();
1812 });
1813
1814 if (failed(inlineOmpRegionCleanup(
1815 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1816 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1817 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1818 "`omp.private` op in");
1819
1820 return success();
1821}
1822
1823/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1825 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1826 // be visible and not inside of function calls. This is enforced by the
1827 // verifier.
1828 return op
1829 ->walk([](Operation *child) {
1830 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1831 return WalkResult::interrupt();
1832 return WalkResult::advance();
1833 })
1834 .wasInterrupted();
1835}
1836
1837static LogicalResult
1838convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1839 LLVM::ModuleTranslation &moduleTranslation) {
1840 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1841 using StorableBodyGenCallbackTy =
1842 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1843
1844 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1845
1846 if (failed(checkImplementationStatus(opInst)))
1847 return failure();
1848
1849 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1850 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1851
1853 collectReductionDecls(sectionsOp, reductionDecls);
1854 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1855 findAllocaInsertPoint(builder, moduleTranslation);
1856
1857 SmallVector<llvm::Value *> privateReductionVariables(
1858 sectionsOp.getNumReductionVars());
1859 DenseMap<Value, llvm::Value *> reductionVariableMap;
1860
1861 MutableArrayRef<BlockArgument> reductionArgs =
1862 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1863
1865 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1866 reductionDecls, privateReductionVariables, reductionVariableMap,
1867 isByRef)))
1868 return failure();
1869
1871
1872 for (Operation &op : *sectionsOp.getRegion().begin()) {
1873 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1874 if (!sectionOp) // omp.terminator
1875 continue;
1876
1877 Region &region = sectionOp.getRegion();
1878 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1879 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1880 builder.restoreIP(codeGenIP);
1881
1882 // map the omp.section reduction block argument to the omp.sections block
1883 // arguments
1884 // TODO: this assumes that the only block arguments are reduction
1885 // variables
1886 assert(region.getNumArguments() ==
1887 sectionsOp.getRegion().getNumArguments());
1888 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1889 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1890 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1891 assert(llvmVal);
1892 moduleTranslation.mapValue(sectionArg, llvmVal);
1893 }
1894
1895 return convertOmpOpRegions(region, "omp.section.region", builder,
1896 moduleTranslation)
1897 .takeError();
1898 };
1899 sectionCBs.push_back(sectionCB);
1900 }
1901
1902 // No sections within omp.sections operation - skip generation. This situation
1903 // is only possible if there is only a terminator operation inside the
1904 // sections operation
1905 if (sectionCBs.empty())
1906 return success();
1907
1908 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1909
1910 // TODO: Perform appropriate actions according to the data-sharing
1911 // attribute (shared, private, firstprivate, ...) of variables.
1912 // Currently defaults to shared.
1913 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1914 llvm::Value &vPtr, llvm::Value *&replacementValue)
1915 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1916 replacementValue = &vPtr;
1917 return codeGenIP;
1918 };
1919
1920 // TODO: Perform finalization actions for variables. This has to be
1921 // called for variables which have destructors/finalizers.
1922 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1923
1924 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1925 bool isCancellable = constructIsCancellable(sectionsOp);
1926 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1927 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1928 moduleTranslation.getOpenMPBuilder()->createSections(
1929 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1930 sectionsOp.getNowait());
1931
1932 if (failed(handleError(afterIP, opInst)))
1933 return failure();
1934
1935 builder.restoreIP(*afterIP);
1936
1937 // Process the reductions if required.
1939 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1940 privateReductionVariables, isByRef, sectionsOp.getNowait());
1941}
1942
1943/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1944static LogicalResult
1945convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1946 LLVM::ModuleTranslation &moduleTranslation) {
1947 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1948 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1949
1950 if (failed(checkImplementationStatus(*singleOp)))
1951 return failure();
1952
1953 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1954 builder.restoreIP(codegenIP);
1955 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1956 builder, moduleTranslation)
1957 .takeError();
1958 };
1959 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1960
1961 // Handle copyprivate
1962 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1963 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1966 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1967 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1969 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1970 llvmCPFuncs.push_back(
1971 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1972 }
1973
1974 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1975 moduleTranslation.getOpenMPBuilder()->createSingle(
1976 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1977 llvmCPFuncs);
1978
1979 if (failed(handleError(afterIP, *singleOp)))
1980 return failure();
1981
1982 builder.restoreIP(*afterIP);
1983 return success();
1984}
1985
1986static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
1987 auto iface =
1988 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1989 // Check that all uses of the reduction block arg has the same distribute op
1990 // parent.
1992 Operation *distOp = nullptr;
1993 for (auto ra : iface.getReductionBlockArgs())
1994 for (auto &use : ra.getUses()) {
1995 auto *useOp = use.getOwner();
1996 // Ignore debug uses.
1997 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1998 debugUses.push_back(useOp);
1999 continue;
2000 }
2001
2002 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
2003 // Use is not inside a distribute op - return false
2004 if (!currentDistOp)
2005 return false;
2006 // Multiple distribute operations - return false
2007 Operation *currentOp = currentDistOp.getOperation();
2008 if (distOp && (distOp != currentOp))
2009 return false;
2010
2011 distOp = currentOp;
2012 }
2013
2014 // If we are going to use distribute reduction then remove any debug uses of
2015 // the reduction parameters in teamsOp. Otherwise they will be left without
2016 // any mapped value in moduleTranslation and will eventually error out.
2017 for (auto *use : debugUses)
2018 use->erase();
2019 return true;
2020}
2021
2022// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
2023static LogicalResult
2024convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2025 LLVM::ModuleTranslation &moduleTranslation) {
2026 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2027 if (failed(checkImplementationStatus(*op)))
2028 return failure();
2029
2030 DenseMap<Value, llvm::Value *> reductionVariableMap;
2031 unsigned numReductionVars = op.getNumReductionVars();
2033 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
2034 llvm::ArrayRef<bool> isByRef;
2035 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2036 findAllocaInsertPoint(builder, moduleTranslation);
2037
2038 // Only do teams reduction if there is no distribute op that captures the
2039 // reduction instead.
2040 bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
2041 if (doTeamsReduction) {
2042 isByRef = getIsByRef(op.getReductionByref());
2043
2044 assert(isByRef.size() == op.getNumReductionVars());
2045
2046 MutableArrayRef<BlockArgument> reductionArgs =
2047 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2048
2049 collectReductionDecls(op, reductionDecls);
2050
2052 op, reductionArgs, builder, moduleTranslation, allocaIP,
2053 reductionDecls, privateReductionVariables, reductionVariableMap,
2054 isByRef)))
2055 return failure();
2056 }
2057
2058 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2060 moduleTranslation, allocaIP);
2061 builder.restoreIP(codegenIP);
2062 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2063 moduleTranslation)
2064 .takeError();
2065 };
2066
2067 llvm::Value *numTeamsLower = nullptr;
2068 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2069 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2070
2071 llvm::Value *numTeamsUpper = nullptr;
2072 if (!op.getNumTeamsUpperVars().empty())
2073 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2074
2075 llvm::Value *threadLimit = nullptr;
2076 if (!op.getThreadLimitVars().empty())
2077 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2078
2079 llvm::Value *ifExpr = nullptr;
2080 if (Value ifVar = op.getIfExpr())
2081 ifExpr = moduleTranslation.lookupValue(ifVar);
2082
2083 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2084 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2085 moduleTranslation.getOpenMPBuilder()->createTeams(
2086 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2087
2088 if (failed(handleError(afterIP, *op)))
2089 return failure();
2090
2091 builder.restoreIP(*afterIP);
2092 if (doTeamsReduction) {
2093 // Process the reductions if required.
2095 op, builder, moduleTranslation, allocaIP, reductionDecls,
2096 privateReductionVariables, isByRef,
2097 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2098 }
2099 return success();
2100}
2101
2102static void
2103buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2104 LLVM::ModuleTranslation &moduleTranslation,
2106 if (dependVars.empty())
2107 return;
2108 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2109 llvm::omp::RTLDependenceKindTy type;
2110 switch (
2111 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2112 case mlir::omp::ClauseTaskDepend::taskdependin:
2113 type = llvm::omp::RTLDependenceKindTy::DepIn;
2114 break;
2115 // The OpenMP runtime requires that the codegen for 'depend' clause for
2116 // 'out' dependency kind must be the same as codegen for 'depend' clause
2117 // with 'inout' dependency.
2118 case mlir::omp::ClauseTaskDepend::taskdependout:
2119 case mlir::omp::ClauseTaskDepend::taskdependinout:
2120 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2121 break;
2122 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2123 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2124 break;
2125 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2126 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2127 break;
2128 };
2129 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2130 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2131 dds.emplace_back(dd);
2132 }
2133}
2134
2135/// Shared implementation of a callback which adds a termiator for the new block
2136/// created for the branch taken when an openmp construct is cancelled. The
2137/// terminator is saved in \p cancelTerminators. This callback is invoked only
2138/// if there is cancellation inside of the taskgroup body.
2139/// The terminator will need to be fixed to branch to the correct block to
2140/// cleanup the construct.
2141static void
2143 llvm::IRBuilderBase &llvmBuilder,
2144 llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
2145 llvm::omp::Directive cancelDirective) {
2146 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2147 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2148
2149 // ip is currently in the block branched to if cancellation occurred.
2150 // We need to create a branch to terminate that block.
2151 llvmBuilder.restoreIP(ip);
2152
2153 // We must still clean up the construct after cancelling it, so we need to
2154 // branch to the block that finalizes the taskgroup.
2155 // That block has not been created yet so use this block as a dummy for now
2156 // and fix this after creating the operation.
2157 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2158 return llvm::Error::success();
2159 };
2160 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2161 // created in case the body contains omp.cancel (which will then expect to be
2162 // able to find this cleanup callback).
2163 ompBuilder.pushFinalizationCB(
2164 {finiCB, cancelDirective, constructIsCancellable(op)});
2165}
2166
2167/// If we cancelled the construct, we should branch to the finalization block of
2168/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2169/// is immediately before the continuation block. Now this finalization has
2170/// been created we can fix the branch.
2171static void
2173 llvm::OpenMPIRBuilder &ompBuilder,
2174 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2175 ompBuilder.popFinalizationCB();
2176 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2177 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2178 assert(cancelBranch->getNumSuccessors() == 1 &&
2179 "cancel branch should have one target");
2180 cancelBranch->setSuccessor(0, constructFini);
2181 }
2182}
2183
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
class TaskContextStructManager {
public:
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// Given the address of the structure, return a GEP for each private variable
  /// in the structure. Null values are added where private decls were skipped
  /// so that the ordering continues to match the private decls.
  /// Must be called after generateTaskContextStruct().
  SmallVector<llvm::Value *>
  createGEPsToPrivateVars(llvm::Value *altStructPtr) const;

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs into the context structure (populated by createGEPsToPrivateVars),
  /// aligned with privateDecls; null entries mark skipped decls.
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Address of the heap-allocated context structure; null until
  /// generateTaskContextStruct() runs (or if it found nothing to store).
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2241
2242void TaskContextStructManager::generateTaskContextStruct() {
2243 if (privateDecls.empty())
2244 return;
2245 privateVarTypes.reserve(privateDecls.size());
2246
2247 for (omp::PrivateClauseOp &privOp : privateDecls) {
2248 // Skip private variables which can safely be allocated and initialised
2249 // inside of the task
2250 if (!privOp.readsFromMold())
2251 continue;
2252 Type mlirType = privOp.getType();
2253 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2254 }
2255
2256 if (privateVarTypes.empty())
2257 return;
2258
2259 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2260 privateVarTypes);
2261
2262 llvm::DataLayout dataLayout =
2263 builder.GetInsertBlock()->getModule()->getDataLayout();
2264 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2265 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2266
2267 // Heap allocate the structure
2268 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2269 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2270 "omp.task.context_ptr");
2271}
2272
2273SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2274 llvm::Value *altStructPtr) const {
2275 SmallVector<llvm::Value *> ret;
2276
2277 // Create GEPs for each struct member
2278 ret.reserve(privateDecls.size());
2279 llvm::Value *zero = builder.getInt32(0);
2280 unsigned i = 0;
2281 for (auto privDecl : privateDecls) {
2282 if (!privDecl.readsFromMold()) {
2283 // Handle this inside of the task so we don't pass unnessecary vars in
2284 ret.push_back(nullptr);
2285 continue;
2286 }
2287 llvm::Value *iVal = builder.getInt32(i);
2288 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2289 ret.push_back(gep);
2290 i += 1;
2291 }
2292 return ret;
2293}
2294
2295void TaskContextStructManager::createGEPsToPrivateVars() {
2296 if (!structPtr)
2297 assert(privateVarTypes.empty());
2298 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2299 // with null values for skipped private decls
2300
2301 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2302}
2303
2304void TaskContextStructManager::freeStructPtr() {
2305 if (!structPtr)
2306 return;
2307
2308 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2309 // Ensure we don't put the call to free() after the terminator
2310 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2311 builder.CreateFree(structPtr);
2312}
2313
2314/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2315static LogicalResult
2316convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2317 LLVM::ModuleTranslation &moduleTranslation) {
2318 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2319 if (failed(checkImplementationStatus(*taskOp)))
2320 return failure();
2321
2322 PrivateVarsInfo privateVarsInfo(taskOp);
2323 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2324 privateVarsInfo.privatizers};
2325
2326 // Allocate and copy private variables before creating the task. This avoids
2327 // accessing invalid memory if (after this scope ends) the private variables
2328 // are initialized from host variables or if the variables are copied into
2329 // from host variables (firstprivate). The insertion point is just before
2330 // where the code for creating and scheduling the task will go. That puts this
2331 // code outside of the outlined task region, which is what we want because
2332 // this way the initialization and copy regions are executed immediately while
2333 // the host variable data are still live.
2334
2335 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2336 findAllocaInsertPoint(builder, moduleTranslation);
2337
2338 // Not using splitBB() because that requires the current block to have a
2339 // terminator.
2340 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2341 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2342 builder.getContext(), "omp.task.start",
2343 /*Parent=*/builder.GetInsertBlock()->getParent());
2344 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2345 builder.SetInsertPoint(branchToTaskStartBlock);
2346
2347 // Now do this again to make the initialization and copy blocks
2348 llvm::BasicBlock *copyBlock =
2349 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2350 llvm::BasicBlock *initBlock =
2351 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2352
2353 // Now the control flow graph should look like
2354 // starter_block:
2355 // <---- where we started when convertOmpTaskOp was called
2356 // br %omp.private.init
2357 // omp.private.init:
2358 // br %omp.private.copy
2359 // omp.private.copy:
2360 // br %omp.task.start
2361 // omp.task.start:
2362 // <---- where we want the insertion point to be when we call createTask()
2363
2364 // Save the alloca insertion point on ModuleTranslation stack for use in
2365 // nested regions.
2367 moduleTranslation, allocaIP);
2368
2369 // Allocate and initialize private variables
2370 builder.SetInsertPoint(initBlock->getTerminator());
2371
2372 // Create task variable structure
2373 taskStructMgr.generateTaskContextStruct();
2374 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2375 // of the body otherwise it will be the GEP not the struct which is fowarded
2376 // to the outlined function. GEPs forwarded in this way are passed in a
2377 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2378 // which may not be executed until after the current stack frame goes out of
2379 // scope.
2380 taskStructMgr.createGEPsToPrivateVars();
2381
2382 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2383 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2384 privateVarsInfo.blockArgs,
2385 taskStructMgr.getLLVMPrivateVarGEPs())) {
2386 // To be handled inside the task.
2387 if (!privDecl.readsFromMold())
2388 continue;
2389 assert(llvmPrivateVarAlloc &&
2390 "reads from mold so shouldn't have been skipped");
2391
2392 llvm::Expected<llvm::Value *> privateVarOrErr =
2393 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2394 blockArg, llvmPrivateVarAlloc, initBlock);
2395 if (!privateVarOrErr)
2396 return handleError(privateVarOrErr, *taskOp.getOperation());
2397
2399
2400 // TODO: this is a bit of a hack for Fortran character boxes.
2401 // Character boxes are passed by value into the init region and then the
2402 // initialized character box is yielded by value. Here we need to store the
2403 // yielded value into the private allocation, and load the private
2404 // allocation to match the type expected by region block arguments.
2405 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2406 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2407 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2408 // Load it so we have the value pointed to by the GEP
2409 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2410 llvmPrivateVarAlloc);
2411 }
2412 assert(llvmPrivateVarAlloc->getType() ==
2413 moduleTranslation.convertType(blockArg.getType()));
2414
2415 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2416 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2417 // stack allocated structure.
2418 }
2419
2420 // firstprivate copy region
2421 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2422 if (failed(copyFirstPrivateVars(
2423 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2424 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2425 taskOp.getPrivateNeedsBarrier())))
2426 return llvm::failure();
2427
2428 // Set up for call to createTask()
2429 builder.SetInsertPoint(taskStartBlock);
2430
2431 auto bodyCB = [&](InsertPointTy allocaIP,
2432 InsertPointTy codegenIP) -> llvm::Error {
2433 // Save the alloca insertion point on ModuleTranslation stack for use in
2434 // nested regions.
2436 moduleTranslation, allocaIP);
2437
2438 // translate the body of the task:
2439 builder.restoreIP(codegenIP);
2440
2441 llvm::BasicBlock *privInitBlock = nullptr;
2442 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2443 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2444 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2445 privateVarsInfo.mlirVars))) {
2446 auto [blockArg, privDecl, mlirPrivVar] = zip;
2447 // This is handled before the task executes
2448 if (privDecl.readsFromMold())
2449 continue;
2450
2451 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2452 llvm::Type *llvmAllocType =
2453 moduleTranslation.convertType(privDecl.getType());
2454 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2455 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2456 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2457
2458 llvm::Expected<llvm::Value *> privateVarOrError =
2459 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2460 blockArg, llvmPrivateVar, privInitBlock);
2461 if (!privateVarOrError)
2462 return privateVarOrError.takeError();
2463 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2464 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2465 }
2466
2467 taskStructMgr.createGEPsToPrivateVars();
2468 for (auto [i, llvmPrivVar] :
2469 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2470 if (!llvmPrivVar) {
2471 assert(privateVarsInfo.llvmVars[i] &&
2472 "This is added in the loop above");
2473 continue;
2474 }
2475 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2476 }
2477
2478 // Find and map the addresses of each variable within the task context
2479 // structure
2480 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2481 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2482 privateVarsInfo.privatizers)) {
2483 // This was handled above.
2484 if (!privateDecl.readsFromMold())
2485 continue;
2486 // Fix broken pass-by-value case for Fortran character boxes
2487 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2488 llvmPrivateVar = builder.CreateLoad(
2489 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2490 }
2491 assert(llvmPrivateVar->getType() ==
2492 moduleTranslation.convertType(blockArg.getType()));
2493 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2494 }
2495
2496 auto continuationBlockOrError = convertOmpOpRegions(
2497 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2498 if (failed(handleError(continuationBlockOrError, *taskOp)))
2499 return llvm::make_error<PreviouslyReportedError>();
2500
2501 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2502
2503 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2504 privateVarsInfo.llvmVars,
2505 privateVarsInfo.privatizers)))
2506 return llvm::make_error<PreviouslyReportedError>();
2507
2508 // Free heap allocated task context structure at the end of the task.
2509 taskStructMgr.freeStructPtr();
2510
2511 return llvm::Error::success();
2512 };
2513
2514 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2515 SmallVector<llvm::BranchInst *> cancelTerminators;
2516 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2517 // which is canceled. This is handled here because it is the task's cleanup
2518 // block which should be branched to.
2519 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2520 llvm::omp::Directive::OMPD_taskgroup);
2521
2523 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2524 moduleTranslation, dds);
2525
2526 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2527 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2528 moduleTranslation.getOpenMPBuilder()->createTask(
2529 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2530 moduleTranslation.lookupValue(taskOp.getFinal()),
2531 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
2532 taskOp.getMergeable(),
2533 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2534 moduleTranslation.lookupValue(taskOp.getPriority()));
2535
2536 if (failed(handleError(afterIP, *taskOp)))
2537 return failure();
2538
2539 // Set the correct branch target for task cancellation
2540 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2541
2542 builder.restoreIP(*afterIP);
2543 return success();
2544}
2545
2546 // Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
// NOTE(review): this copy of the file looks like a rendered (doxygen) export:
// each line is prefixed with its original line number and a few hyperlinked
// source lines were dropped by the extraction (orig. 2577, 2637, 2728, 2745,
// 2747, 2766 below). The dropped declarations must be recovered from the
// upstream file before this can compile; the NOTE(review) markers below flag
// each gap.
//
// High-level shape (from the visible code): private/firstprivate copies that
// read from their mold are initialized up front into a heap-allocated "task
// context" structure managed by TaskContextStructManager; the body callback
// maps block arguments to GEPs into that structure; the task-duplication
// callback clones the structure every time the runtime splits the taskloop
// into more tasks.
2547 static LogicalResult
2548 convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
2549 LLVM::ModuleTranslation &moduleTranslation) {
2550 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2551 auto taskloopOp = cast<omp::TaskloopOp>(opInst);
2552 if (failed(checkImplementationStatus(opInst)))
2553 return failure();
2554
2555 // It stores the pointer of allocated firstprivate copies,
2556 // which can be used later for freeing the allocated space.
2557 SmallVector<llvm::Value *> llvmFirstPrivateVars;
2558 PrivateVarsInfo privateVarsInfo(taskloopOp);
2559 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2560 privateVarsInfo.privatizers};
2561
2562 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2563 findAllocaInsertPoint(builder, moduleTranslation);
2564
// Split the current block so that privatization init/copy code gets its own
// blocks ahead of the block where createTaskloop() will be invoked.
2565 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2566 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
2567 builder.getContext(), "omp.taskloop.start",
2568 /*Parent=*/builder.GetInsertBlock()->getParent());
2569 llvm::Instruction *branchToTaskloopStartBlock =
2570 builder.CreateBr(taskloopStartBlock);
2571 builder.SetInsertPoint(branchToTaskloopStartBlock);
2572
2573 llvm::BasicBlock *copyBlock =
2574 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2575 llvm::BasicBlock *initBlock =
2576 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
// NOTE(review): a dropped line (orig. 2577) presumably opened an
// OpenMPAllocaStackFrame save-scope whose argument list continues on the
// "moduleTranslation, allocaIP);" line below — confirm against upstream.
2578
2579 moduleTranslation, allocaIP);
2580
2581 // Allocate and initialize private variables
2582 builder.SetInsertPoint(initBlock->getTerminator());
2583
2584 // TODO: don't allocate if the loop has zero iterations.
2585 taskStructMgr.generateTaskContextStruct();
2586 taskStructMgr.createGEPsToPrivateVars();
2587
2588 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
2589
// Run init regions (before the taskloop) only for privatizers that read from
// their mold; the remainder are initialized inside the task body callback.
2590 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2591 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2592 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
2593 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
2594 // To be handled inside the taskloop.
2595 if (!privDecl.readsFromMold())
2596 continue;
2597 assert(llvmPrivateVarAlloc &&
2598 "reads from mold so shouldn't have been skipped");
2599
2600 llvm::Expected<llvm::Value *> privateVarOrErr =
2601 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2602 blockArg, llvmPrivateVarAlloc, initBlock);
2603 if (!privateVarOrErr)
2604 return handleError(privateVarOrErr, *taskloopOp.getOperation());
2605
2606 llvmFirstPrivateVars[i] = privateVarOrErr.get();
2607
2608 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2609 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2610
// Value-yielding init regions (e.g. Fortran character boxes) are stored back
// into the context-struct slot so the slot, not the SSA value, is canonical.
2611 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2612 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2613 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2614 // Load it so we have the value pointed to by the GEP
2615 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2616 llvmPrivateVarAlloc);
2617 }
2618 assert(llvmPrivateVarAlloc->getType() ==
2619 moduleTranslation.convertType(blockArg.getType()));
2620 }
2621
2622 // firstprivate copy region
2623 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2624 if (failed(copyFirstPrivateVars(
2625 taskloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2626 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2627 taskloopOp.getPrivateNeedsBarrier())))
2628 return llvm::failure();
2629
2630 // Set up insertion point for call to createTaskloop()
2631 builder.SetInsertPoint(taskloopStartBlock);
2632
2633 auto bodyCB = [&](InsertPointTy allocaIP,
2634 InsertPointTy codegenIP) -> llvm::Error {
2635 // Save the alloca insertion point on ModuleTranslation stack for use in
2636 // nested regions.
// NOTE(review): dropped line (orig. 2637) — presumably the
// OpenMPAllocaStackFrame SaveStateStack declaration whose argument list
// continues below; confirm against upstream.
2638 moduleTranslation, allocaIP);
2639
2640 // translate the body of the taskloop:
2641 builder.restoreIP(codegenIP);
2642
// Privatizers that do NOT read from the mold are allocated and initialized
// here, inside the outlined task body, instead of in the context structure.
2643 llvm::BasicBlock *privInitBlock = nullptr;
2644 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2645 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2646 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2647 privateVarsInfo.mlirVars))) {
2648 auto [blockArg, privDecl, mlirPrivVar] = zip;
2649 // This is handled before the task executes
2650 if (privDecl.readsFromMold())
2651 continue;
2652
2653 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2654 llvm::Type *llvmAllocType =
2655 moduleTranslation.convertType(privDecl.getType());
2656 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2657 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2658 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2659
2660 llvm::Expected<llvm::Value *> privateVarOrError =
2661 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2662 blockArg, llvmPrivateVar, privInitBlock);
2663 if (!privateVarOrError)
2664 return privateVarOrError.takeError();
2665 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2666 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2667 }
2668
// Re-create the GEPs relative to the task's own copy of the context struct
// and overwrite llvmVars entries for variables that live in the struct.
2669 taskStructMgr.createGEPsToPrivateVars();
2670 for (auto [i, llvmPrivVar] :
2671 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2672 if (!llvmPrivVar) {
2673 assert(privateVarsInfo.llvmVars[i] &&
2674 "This is added in the loop above");
2675 continue;
2676 }
2677 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2678 }
2679
2680 // Find and map the addresses of each variable within the taskloop context
2681 // structure
2682 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2683 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2684 privateVarsInfo.privatizers)) {
2685 // This was handled above.
2686 if (!privateDecl.readsFromMold())
2687 continue;
2688 // Fix broken pass-by-value case for Fortran character boxes
2689 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2690 llvmPrivateVar = builder.CreateLoad(
2691 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2692 }
2693 assert(llvmPrivateVar->getType() ==
2694 moduleTranslation.convertType(blockArg.getType()));
2695 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2696 }
2697
2698 auto continuationBlockOrError =
2699 convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
2700 builder, moduleTranslation);
2701
2702 if (failed(handleError(continuationBlockOrError, opInst)))
2703 return llvm::make_error<PreviouslyReportedError>();
2704
2705 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2706
2707 // This is freeing the private variables as mapped inside of the task: these
2708 // will be per-task private copies possibly after task duplication. This is
2709 // handled transparently by how these are passed to the structure passed
2710 // into the outlined function. When the task is duplicated, that structure
2711 // is duplicated too.
2712 if (failed(cleanupPrivateVars(builder, moduleTranslation,
2713 taskloopOp.getLoc(), privateVarsInfo.llvmVars,
2714 privateVarsInfo.privatizers)))
2715 return llvm::make_error<PreviouslyReportedError>();
2716 // Similarly, the task context structure freed inside the task is the
2717 // per-task copy after task duplication.
2718 taskStructMgr.freeStructPtr();
2719
2720 return llvm::Error::success();
2721 };
2722
2723 // Taskloop divides into an appropriate number of tasks by repeatedly
2724 // duplicating the original task. Each time this is done, the task context
2725 // structure must be duplicated too.
2726 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
2727 llvm::Value *destPtr, llvm::Value *srcPtr)
// NOTE(review): dropped line (orig. 2728) — presumably the lambda's trailing
// return type (an InsertPoint-or-error result, given the `builder.saveIP()`
// and `make_error` returns below) plus the opening brace; confirm upstream.
2729 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2730 builder.restoreIP(codegenIP);
2731
2732 llvm::Type *ptrTy =
2733 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
2734 llvm::Value *src =
2735 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
2736
2737 TaskContextStructManager &srcStructMgr = taskStructMgr;
2738 TaskContextStructManager destStructMgr(builder, moduleTranslation,
2739 privateVarsInfo.privatizers);
2740 destStructMgr.generateTaskContextStruct();
2741 llvm::Value *dest = destStructMgr.getStructPtr();
2742 dest->setName("omp.taskloop.context.dest");
2743 builder.CreateStore(dest, destPtr);
2744
// NOTE(review): dropped lines (orig. 2745 and 2747) — presumably the
// declarations of `srcGEPs` and `destGEPs`, which are used by the loops and
// the copyFirstPrivateVars call below; confirm against upstream.
2746 srcStructMgr.createGEPsToPrivateVars(src);
2748 destStructMgr.createGEPsToPrivateVars(dest);
2749
2750 // Inline init regions.
2751 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
2752 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
2753 privateVarsInfo.blockArgs, destGEPs)) {
2754 // To be handled inside task body.
2755 if (!privDecl.readsFromMold())
2756 continue;
2757 assert(llvmPrivateVarAlloc &&
2758 "reads from mold so shouldn't have been skipped");
2759
2760 llvm::Expected<llvm::Value *> privateVarOrErr =
2761 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
2762 llvmPrivateVarAlloc, builder.GetInsertBlock());
2763 if (!privateVarOrErr)
2764 return privateVarOrErr.takeError();
2765
// NOTE(review): dropped line (orig. 2766) — content unknown from this copy;
// confirm against upstream.
2767
2768 // TODO: this is a bit of a hack for Fortran character boxes.
2769 // Character boxes are passed by value into the init region and then the
2770 // initialized character box is yielded by value. Here we need to store
2771 // the yielded value into the private allocation, and load the private
2772 // allocation to match the type expected by region block arguments.
2773 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2774 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2775 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2776 // Load it so we have the value pointed to by the GEP
2777 llvmPrivateVarAlloc = builder.CreateLoad(
2778 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
2779 }
2780 assert(llvmPrivateVarAlloc->getType() ==
2781 moduleTranslation.convertType(blockArg.getType()));
2782
2783 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
2784 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
2785 // through a stack allocated structure.
2786 }
2787
2788 if (failed(copyFirstPrivateVars(
2789 &opInst, builder, moduleTranslation, srcGEPs, destGEPs,
2790 privateVarsInfo.privatizers, taskloopOp.getPrivateNeedsBarrier())))
2791 return llvm::make_error<PreviouslyReportedError>();
2792
2793 return builder.saveIP();
2794 };
2795
2796 auto loopOp = cast<omp::LoopNestOp>(taskloopOp.getWrappedLoop());
2797
// Callback handing the already-collapsed canonical loop to createTaskloop().
2798 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
2799 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2800 return loopInfo;
2801 };
2802
// `sched` encodes the taskloop schedule selector passed to the runtime:
// 0 = default, 1 = grainsize, 2 = num_tasks (see inline comments below).
2803 llvm::Value *ifCond = nullptr;
2804 llvm::Value *grainsize = nullptr;
2805 int sched = 0; // default
2806 mlir::Value grainsizeVal = taskloopOp.getGrainsize();
2807 mlir::Value numTasksVal = taskloopOp.getNumTasks();
2808 if (Value ifVar = taskloopOp.getIfExpr())
2809 ifCond = moduleTranslation.lookupValue(ifVar);
2810 if (grainsizeVal) {
2811 grainsize = moduleTranslation.lookupValue(grainsizeVal);
2812 sched = 1; // grainsize
2813 } else if (numTasksVal) {
2814 grainsize = moduleTranslation.lookupValue(numTasksVal);
2815 sched = 2; // num_tasks
2816 }
2817
// Only register a duplication callback when there is a context structure
// that actually needs cloning.
2818 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
2819 if (taskStructMgr.getStructPtr())
2820 taskDupOrNull = taskDupCB;
2821
2822 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2823 SmallVector<llvm::BranchInst *> cancelTerminators;
2824 // The directive to match here is OMPD_taskgroup because it is the
2825 // taskgroup which is canceled. This is handled here because it is the
2826 // task's cleanup block which should be branched to. It doesn't depend upon
2827 // nogroup because even in that case the taskloop might still be inside an
2828 // explicit taskgroup.
2829 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskloopOp,
2830 llvm::omp::Directive::OMPD_taskgroup);
2831
2832 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2833 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2834 moduleTranslation.getOpenMPBuilder()->createTaskloop(
2835 ompLoc, allocaIP, bodyCB, loopInfo,
2836 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[0]),
2837 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[0]),
2838 moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]),
2839 taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
2840 sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
2841 taskloopOp.getMergeable(),
2842 moduleTranslation.lookupValue(taskloopOp.getPriority()),
2843 taskDupOrNull, taskStructMgr.getStructPtr());
2844
2845 if (failed(handleError(afterIP, opInst)))
2846 return failure();
2847
// Retarget the cancellation branches now that the cleanup block is known.
2848 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2849
2850 builder.restoreIP(*afterIP);
2851 return success();
2852}
2853
2854/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2855static LogicalResult
2856convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2857 LLVM::ModuleTranslation &moduleTranslation) {
2858 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2859 if (failed(checkImplementationStatus(*tgOp)))
2860 return failure();
2861
2862 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2863 builder.restoreIP(codegenIP);
2864 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2865 builder, moduleTranslation)
2866 .takeError();
2867 };
2868
2869 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2870 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2871 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2872 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2873 bodyCB);
2874
2875 if (failed(handleError(afterIP, *tgOp)))
2876 return failure();
2877
2878 builder.restoreIP(*afterIP);
2879 return success();
2880}
2881
2882static LogicalResult
2883convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2884 LLVM::ModuleTranslation &moduleTranslation) {
2885 if (failed(checkImplementationStatus(*twOp)))
2886 return failure();
2887
2888 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
2889 return success();
2890}
2891
2892 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
// NOTE(review): this copy of the file appears to be a rendered (doxygen)
// export — lines are prefixed with their original numbers, and three
// hyperlinked declarations were dropped by the extraction (orig. 2934, 2942,
// 3015 below). Recover them from the upstream file before compiling; the
// NOTE(review) markers below flag each gap.
2893 static LogicalResult
2894 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2895 LLVM::ModuleTranslation &moduleTranslation) {
2896 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2897 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2898 if (failed(checkImplementationStatus(opInst)))
2899 return failure();
2900
2901 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2902 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
2903 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2904
2905 // Static is the default.
2906 auto schedule =
2907 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2908
2909 // Find the loop configuration.
2910 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
2911 llvm::Type *ivType = step->getType();
2912 llvm::Value *chunk = nullptr;
2913 if (wsloopOp.getScheduleChunk()) {
2914 llvm::Value *chunkVar =
2915 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
// Normalize the chunk size to the induction-variable type.
2916 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2917 }
2918
// When directly nested in omp.distribute, pick up the dist_schedule clause
// so it can be forwarded to applyWorkshareLoop below.
2919 omp::DistributeOp distributeOp = nullptr;
2920 llvm::Value *distScheduleChunk = nullptr;
2921 bool hasDistSchedule = false;
2922 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
2923 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
2924 hasDistSchedule = distributeOp.getDistScheduleStatic();
2925 if (distributeOp.getDistScheduleChunkSize()) {
2926 llvm::Value *chunkVar = moduleTranslation.lookupValue(
2927 distributeOp.getDistScheduleChunkSize());
2928 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
2929 }
2930 }
2931
2932 PrivateVarsInfo privateVarsInfo(wsloopOp);
2933
// NOTE(review): dropped line (orig. 2934) — presumably the declaration of
// `reductionDecls` that collectReductionDecls fills below; confirm upstream.
2935 collectReductionDecls(wsloopOp, reductionDecls);
2936 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2937 findAllocaInsertPoint(builder, moduleTranslation);
2938
2939 SmallVector<llvm::Value *> privateReductionVariables(
2940 wsloopOp.getNumReductionVars());
2941
// NOTE(review): dropped line (orig. 2942) — presumably the declaration of
// `afterAllocas` (result of the private-var allocation helper whose argument
// list continues below); confirm against upstream.
2943 builder, moduleTranslation, privateVarsInfo, allocaIP);
2944 if (handleError(afterAllocas, opInst).failed())
2945 return failure();
2946
2947 DenseMap<Value, llvm::Value *> reductionVariableMap;
2948
2949 MutableArrayRef<BlockArgument> reductionArgs =
2950 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2951
2952 SmallVector<DeferredStore> deferredStores;
2953
2954 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
2955 moduleTranslation, allocaIP, reductionDecls,
2956 privateReductionVariables, reductionVariableMap,
2957 deferredStores, isByRef)))
2958 return failure();
2959
2960 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
2961 opInst)
2962 .failed())
2963 return failure();
2964
2965 if (failed(copyFirstPrivateVars(
2966 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2967 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
2968 wsloopOp.getPrivateNeedsBarrier())))
2969 return failure();
2970
2971 assert(afterAllocas.get()->getSinglePredecessor());
2972 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2973 moduleTranslation,
2974 afterAllocas.get()->getSinglePredecessor(),
2975 reductionDecls, privateReductionVariables,
2976 reductionVariableMap, isByRef, deferredStores)))
2977 return failure();
2978
2979 // TODO: Handle doacross loops when the ordered clause has a parameter.
2980 bool isOrdered = wsloopOp.getOrdered().has_value();
2981 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2982 bool isSimd = wsloopOp.getScheduleSimd();
2983 bool loopNeedsBarrier = !wsloopOp.getNowait();
2984
2985 // The only legal way for the direct parent to be omp.distribute is that this
2986 // represents 'distribute parallel do'. Otherwise, this is a regular
2987 // worksharing loop.
2988 llvm::omp::WorksharingLoopType workshareLoopType =
2989 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
2990 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2991 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2992
2993 SmallVector<llvm::BranchInst *> cancelTerminators;
2994 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
2995 llvm::omp::Directive::OMPD_for);
2996
2997 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2998
2999 // Initialize linear variables and linear step
3000 LinearClauseProcessor linearClauseProcessor;
3001
3002 if (!wsloopOp.getLinearVars().empty()) {
3003 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3004 for (mlir::Attribute linearVarType : linearVarTypes)
3005 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3006
3007 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3008 linearClauseProcessor.createLinearVar(
3009 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3010 idx);
3011 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3012 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3013 }
3014
// NOTE(review): dropped line (orig. 3015) — presumably the declaration of
// `regionBlock` (result of convertOmpOpRegions whose argument list continues
// below); confirm against upstream.
3016 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3017
3018 if (failed(handleError(regionBlock, opInst)))
3019 return failure();
3020
3021 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3022
3023 // Emit Initialization and Update IR for linear variables
3024 if (!wsloopOp.getLinearVars().empty()) {
3025 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3026 loopInfo->getPreheader());
3027 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3028 moduleTranslation.getOpenMPBuilder()->createBarrier(
3029 builder.saveIP(), llvm::omp::OMPD_barrier);
3030 if (failed(handleError(afterBarrierIP, *loopOp)))
3031 return failure();
3032 builder.restoreIP(*afterBarrierIP);
3033 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3034 loopInfo->getIndVar());
3035 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3036 }
3037
3038 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3039
3040 // Check if we can generate no-loop kernel
3041 bool noLoopMode = false;
3042 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3043 if (targetOp) {
3044 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3045 // We need this check because, without it, noLoopMode would be set to true
3046 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3047 // that loop is not the top-level SPMD one.
3048 if (loopOp == targetCapturedOp) {
3049 omp::TargetRegionFlags kernelFlags =
3050 targetOp.getKernelExecFlags(targetCapturedOp);
// no-loop mode requires spmd+no_loop flags and the absence of generic mode.
3051 if (omp::bitEnumContainsAll(kernelFlags,
3052 omp::TargetRegionFlags::spmd |
3053 omp::TargetRegionFlags::no_loop) &&
3054 !omp::bitEnumContainsAny(kernelFlags,
3055 omp::TargetRegionFlags::generic))
3056 noLoopMode = true;
3057 }
3058 }
3059
3060 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3061 ompBuilder->applyWorkshareLoop(
3062 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3063 convertToScheduleKind(schedule), chunk, isSimd,
3064 scheduleMod == omp::ScheduleModifier::monotonic,
3065 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3066 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3067
3068 if (failed(handleError(wsloopIP, opInst)))
3069 return failure();
3070
3071 // Emit finalization and in-place rewrites for linear vars.
3072 if (!wsloopOp.getLinearVars().empty()) {
3073 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3074 assert(loopInfo->getLastIter() &&
3075 "`lastiter` in CanonicalLoopInfo is nullptr");
3076 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3077 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3078 loopInfo->getLastIter());
3079 if (failed(handleError(afterBarrierIP, *loopOp)))
3080 return failure();
3081 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3082 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3083 index);
3084 builder.restoreIP(oldIP);
3085 }
3086
3087 // Set the correct branch target for task cancellation
3088 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3089
3090 // Process the reductions if required.
3091 if (failed(createReductionsAndCleanup(
3092 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3093 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3094 /*isTeamsReduction=*/false)))
3095 return failure();
3096
3097 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
3098 privateVarsInfo.llvmVars,
3099 privateVarsInfo.privatizers);
3100}
3101
3102/// Converts the OpenMP parallel operation to LLVM IR.
3103static LogicalResult
3104convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3105 LLVM::ModuleTranslation &moduleTranslation) {
3106 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3107 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3108 assert(isByRef.size() == opInst.getNumReductionVars());
3109 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3110 bool isCancellable = constructIsCancellable(opInst);
3111
3112 if (failed(checkImplementationStatus(*opInst)))
3113 return failure();
3114
3115 PrivateVarsInfo privateVarsInfo(opInst);
3116
3117 // Collect reduction declarations
3119 collectReductionDecls(opInst, reductionDecls);
3120 SmallVector<llvm::Value *> privateReductionVariables(
3121 opInst.getNumReductionVars());
3122 SmallVector<DeferredStore> deferredStores;
3123
3124 auto bodyGenCB = [&](InsertPointTy allocaIP,
3125 InsertPointTy codeGenIP) -> llvm::Error {
3127 builder, moduleTranslation, privateVarsInfo, allocaIP);
3128 if (handleError(afterAllocas, *opInst).failed())
3129 return llvm::make_error<PreviouslyReportedError>();
3130
3131 // Allocate reduction vars
3132 DenseMap<Value, llvm::Value *> reductionVariableMap;
3133
3134 MutableArrayRef<BlockArgument> reductionArgs =
3135 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3136
3137 allocaIP =
3138 InsertPointTy(allocaIP.getBlock(),
3139 allocaIP.getBlock()->getTerminator()->getIterator());
3140
3141 if (failed(allocReductionVars(
3142 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3143 reductionDecls, privateReductionVariables, reductionVariableMap,
3144 deferredStores, isByRef)))
3145 return llvm::make_error<PreviouslyReportedError>();
3146
3147 assert(afterAllocas.get()->getSinglePredecessor());
3148 builder.restoreIP(codeGenIP);
3149
3150 if (handleError(
3151 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3152 *opInst)
3153 .failed())
3154 return llvm::make_error<PreviouslyReportedError>();
3155
3156 if (failed(copyFirstPrivateVars(
3157 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3158 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3159 opInst.getPrivateNeedsBarrier())))
3160 return llvm::make_error<PreviouslyReportedError>();
3161
3162 if (failed(
3163 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3164 afterAllocas.get()->getSinglePredecessor(),
3165 reductionDecls, privateReductionVariables,
3166 reductionVariableMap, isByRef, deferredStores)))
3167 return llvm::make_error<PreviouslyReportedError>();
3168
3169 // Save the alloca insertion point on ModuleTranslation stack for use in
3170 // nested regions.
3172 moduleTranslation, allocaIP);
3173
3174 // ParallelOp has only one region associated with it.
3176 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3177 if (!regionBlock)
3178 return regionBlock.takeError();
3179
3180 // Process the reductions if required.
3181 if (opInst.getNumReductionVars() > 0) {
3182 // Collect reduction info
3184 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3186 owningReductionGenRefDataPtrGens;
3188 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3189 owningReductionGens, owningAtomicReductionGens,
3190 owningReductionGenRefDataPtrGens,
3191 privateReductionVariables, reductionInfos, isByRef);
3192
3193 // Move to region cont block
3194 builder.SetInsertPoint((*regionBlock)->getTerminator());
3195
3196 // Generate reductions from info
3197 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3198 builder.SetInsertPoint(tempTerminator);
3199
3200 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3201 ompBuilder->createReductions(
3202 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3203 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3204 if (!contInsertPoint)
3205 return contInsertPoint.takeError();
3206
3207 if (!contInsertPoint->getBlock())
3208 return llvm::make_error<PreviouslyReportedError>();
3209
3210 tempTerminator->eraseFromParent();
3211 builder.restoreIP(*contInsertPoint);
3212 }
3213
3214 return llvm::Error::success();
3215 };
3216
3217 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3218 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3219 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3220 // bodyGenCB.
3221 replVal = &val;
3222 return codeGenIP;
3223 };
3224
3225 // TODO: Perform finalization actions for variables. This has to be
3226 // called for variables which have destructors/finalizers.
3227 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3228 InsertPointTy oldIP = builder.saveIP();
3229 builder.restoreIP(codeGenIP);
3230
3231 // if the reduction has a cleanup region, inline it here to finalize the
3232 // reduction variables
3233 SmallVector<Region *> reductionCleanupRegions;
3234 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3235 [](omp::DeclareReductionOp reductionDecl) {
3236 return &reductionDecl.getCleanupRegion();
3237 });
3238 if (failed(inlineOmpRegionCleanup(
3239 reductionCleanupRegions, privateReductionVariables,
3240 moduleTranslation, builder, "omp.reduction.cleanup")))
3241 return llvm::createStringError(
3242 "failed to inline `cleanup` region of `omp.declare_reduction`");
3243
3244 if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
3245 privateVarsInfo.llvmVars,
3246 privateVarsInfo.privatizers)))
3247 return llvm::make_error<PreviouslyReportedError>();
3248
3249 // If we could be performing cancellation, add the cancellation barrier on
3250 // the way out of the outlined region.
3251 if (isCancellable) {
3252 auto IPOrErr = ompBuilder->createBarrier(
3253 llvm::OpenMPIRBuilder::LocationDescription(builder),
3254 llvm::omp::Directive::OMPD_unknown,
3255 /* ForceSimpleCall */ false,
3256 /* CheckCancelFlag */ false);
3257 if (!IPOrErr)
3258 return IPOrErr.takeError();
3259 }
3260
3261 builder.restoreIP(oldIP);
3262 return llvm::Error::success();
3263 };
3264
3265 llvm::Value *ifCond = nullptr;
3266 if (auto ifVar = opInst.getIfExpr())
3267 ifCond = moduleTranslation.lookupValue(ifVar);
3268 llvm::Value *numThreads = nullptr;
3269 if (!opInst.getNumThreadsVars().empty())
3270 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
3271 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3272 if (auto bind = opInst.getProcBindKind())
3273 pbKind = getProcBindKind(*bind);
3274
3275 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3276 findAllocaInsertPoint(builder, moduleTranslation);
3277 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3278
3279 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3280 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3281 ifCond, numThreads, pbKind, isCancellable);
3282
3283 if (failed(handleError(afterIP, *opInst)))
3284 return failure();
3285
3286 builder.restoreIP(*afterIP);
3287 return success();
3288}
3289
3290/// Convert Order attribute to llvm::omp::OrderKind.
3291static llvm::omp::OrderKind
3292convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3293 if (!o)
3294 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3295 switch (*o) {
3296 case omp::ClauseOrderKind::Concurrent:
3297 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3298 }
3299 llvm_unreachable("Unknown ClauseOrderKind kind");
3300}
3301
3302/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
3303static LogicalResult
3304convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
3305 LLVM::ModuleTranslation &moduleTranslation) {
3306 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3307 auto simdOp = cast<omp::SimdOp>(opInst);
3308
3309 if (failed(checkImplementationStatus(opInst)))
3310 return failure();
3311
3312 PrivateVarsInfo privateVarsInfo(simdOp);
3313
3314 MutableArrayRef<BlockArgument> reductionArgs =
3315 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3316 DenseMap<Value, llvm::Value *> reductionVariableMap;
3317 SmallVector<llvm::Value *> privateReductionVariables(
3318 simdOp.getNumReductionVars());
3319 SmallVector<DeferredStore> deferredStores;
3321 collectReductionDecls(simdOp, reductionDecls);
3322 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
3323 assert(isByRef.size() == simdOp.getNumReductionVars());
3324
3325 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3326 findAllocaInsertPoint(builder, moduleTranslation);
3327
3329 builder, moduleTranslation, privateVarsInfo, allocaIP);
3330 if (handleError(afterAllocas, opInst).failed())
3331 return failure();
3332
3333 // Initialize linear variables and linear step
3334 LinearClauseProcessor linearClauseProcessor;
3335
3336 if (!simdOp.getLinearVars().empty()) {
3337 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3338 for (mlir::Attribute linearVarType : linearVarTypes)
3339 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3340 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3341 bool isImplicit = false;
3342 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3343 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
3344 // If the linear variable is implicit, reuse the already
3345 // existing llvm::Value
3346 if (linearVar == mlirPrivVar) {
3347 isImplicit = true;
3348 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3349 llvmPrivateVar, idx);
3350 break;
3351 }
3352 }
3353
3354 if (!isImplicit)
3355 linearClauseProcessor.createLinearVar(
3356 builder, moduleTranslation,
3357 moduleTranslation.lookupValue(linearVar), idx);
3358 }
3359 for (mlir::Value linearStep : simdOp.getLinearStepVars())
3360 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3361 }
3362
3363 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
3364 moduleTranslation, allocaIP, reductionDecls,
3365 privateReductionVariables, reductionVariableMap,
3366 deferredStores, isByRef)))
3367 return failure();
3368
3369 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3370 opInst)
3371 .failed())
3372 return failure();
3373
3374 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
3375 // SIMD.
3376
3377 assert(afterAllocas.get()->getSinglePredecessor());
3378 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3379 moduleTranslation,
3380 afterAllocas.get()->getSinglePredecessor(),
3381 reductionDecls, privateReductionVariables,
3382 reductionVariableMap, isByRef, deferredStores)))
3383 return failure();
3384
3385 llvm::ConstantInt *simdlen = nullptr;
3386 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3387 simdlen = builder.getInt64(simdlenVar.value());
3388
3389 llvm::ConstantInt *safelen = nullptr;
3390 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3391 safelen = builder.getInt64(safelenVar.value());
3392
3393 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3394 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
3395
3396 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3397 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3398 mlir::OperandRange operands = simdOp.getAlignedVars();
3399 for (size_t i = 0; i < operands.size(); ++i) {
3400 llvm::Value *alignment = nullptr;
3401 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
3402 llvm::Type *ty = llvmVal->getType();
3403
3404 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
3405 alignment = builder.getInt64(intAttr.getInt());
3406 assert(ty->isPointerTy() && "Invalid type for aligned variable");
3407 assert(alignment && "Invalid alignment value");
3408
3409 // Check if the alignment value is not a power of 2. If so, skip emitting
3410 // alignment.
3411 if (!intAttr.getValue().isPowerOf2())
3412 continue;
3413
3414 auto curInsert = builder.saveIP();
3415 builder.SetInsertPoint(sourceBlock);
3416 llvmVal = builder.CreateLoad(ty, llvmVal);
3417 builder.restoreIP(curInsert);
3418 alignedVars[llvmVal] = alignment;
3419 }
3420
3422 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
3423
3424 if (failed(handleError(regionBlock, opInst)))
3425 return failure();
3426
3427 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3428 // Emit Initialization for linear variables
3429 if (simdOp.getLinearVars().size()) {
3430 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3431 loopInfo->getPreheader());
3432
3433 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3434 loopInfo->getIndVar());
3435 }
3436 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3437
3438 ompBuilder->applySimd(loopInfo, alignedVars,
3439 simdOp.getIfExpr()
3440 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
3441 : nullptr,
3442 order, simdlen, safelen);
3443
3444 linearClauseProcessor.emitStoresForLinearVar(builder);
3445 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
3446 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3447 index);
3448
3449 // We now need to reduce the per-simd-lane reduction variable into the
3450 // original variable. This works a bit differently to other reductions (e.g.
3451 // wsloop) because we don't need to call into the OpenMP runtime to handle
3452 // threads: everything happened in this one thread.
3453 for (auto [i, tuple] : llvm::enumerate(
3454 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
3455 privateReductionVariables))) {
3456 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
3457
3458 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
3459 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
3460 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
3461
3462 // We have one less load for by-ref case because that load is now inside of
3463 // the reduction region.
3464 llvm::Value *redValue = originalVariable;
3465 if (!byRef)
3466 redValue =
3467 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
3468 llvm::Value *privateRedValue = builder.CreateLoad(
3469 reductionType, privateReductionVar, "red.private.value." + Twine(i));
3470 llvm::Value *reduced;
3471
3472 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
3473 if (failed(handleError(res, opInst)))
3474 return failure();
3475 builder.restoreIP(res.get());
3476
3477 // For by-ref case, the store is inside of the reduction region.
3478 if (!byRef)
3479 builder.CreateStore(reduced, originalVariable);
3480 }
3481
3482 // After the construct, deallocate private reduction variables.
3483 SmallVector<Region *> reductionRegions;
3484 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
3485 [](omp::DeclareReductionOp reductionDecl) {
3486 return &reductionDecl.getCleanupRegion();
3487 });
3488 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
3489 moduleTranslation, builder,
3490 "omp.reduction.cleanup")))
3491 return failure();
3492
3493 return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
3494 privateVarsInfo.llvmVars,
3495 privateVarsInfo.privatizers);
3496}
3497
3498/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
3499static LogicalResult
3500convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
3501 LLVM::ModuleTranslation &moduleTranslation) {
3502 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3503 auto loopOp = cast<omp::LoopNestOp>(opInst);
3504
3505 if (failed(checkImplementationStatus(opInst)))
3506 return failure();
3507
3508 // Set up the source location value for OpenMP runtime.
3509 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3510
3511 // Generator of the canonical loop body.
3514 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
3515 llvm::Value *iv) -> llvm::Error {
3516 // Make sure further conversions know about the induction variable.
3517 moduleTranslation.mapValue(
3518 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
3519
3520 // Capture the body insertion point for use in nested loops. BodyIP of the
3521 // CanonicalLoopInfo always points to the beginning of the entry block of
3522 // the body.
3523 bodyInsertPoints.push_back(ip);
3524
3525 if (loopInfos.size() != loopOp.getNumLoops() - 1)
3526 return llvm::Error::success();
3527
3528 // Convert the body of the loop.
3529 builder.restoreIP(ip);
3531 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
3532 if (!regionBlock)
3533 return regionBlock.takeError();
3534
3535 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3536 return llvm::Error::success();
3537 };
3538
3539 // Delegate actual loop construction to the OpenMP IRBuilder.
3540 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
3541 // loop, i.e. it has a positive step, uses signed integer semantics.
3542 // Reconsider this code when the nested loop operation clearly supports more
3543 // cases.
3544 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
3545 llvm::Value *lowerBound =
3546 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
3547 llvm::Value *upperBound =
3548 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
3549 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
3550
3551 // Make sure loop trip count are emitted in the preheader of the outermost
3552 // loop at the latest so that they are all available for the new collapsed
3553 // loop will be created below.
3554 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
3555 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
3556 if (i != 0) {
3557 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
3558 ompLoc.DL);
3559 computeIP = loopInfos.front()->getPreheaderIP();
3560 }
3561
3563 ompBuilder->createCanonicalLoop(
3564 loc, bodyGen, lowerBound, upperBound, step,
3565 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
3566
3567 if (failed(handleError(loopResult, *loopOp)))
3568 return failure();
3569
3570 loopInfos.push_back(*loopResult);
3571 }
3572
3573 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
3574 loopInfos.front()->getAfterIP();
3575
3576 // Do tiling.
3577 if (const auto &tiles = loopOp.getTileSizes()) {
3578 llvm::Type *ivType = loopInfos.front()->getIndVarType();
3580
3581 for (auto tile : tiles.value()) {
3582 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
3583 tileSizes.push_back(tileVal);
3584 }
3585
3586 std::vector<llvm::CanonicalLoopInfo *> newLoops =
3587 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
3588
3589 // Update afterIP to get the correct insertion point after
3590 // tiling.
3591 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
3592 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
3593 afterIP = {afterAfterBB, afterAfterBB->begin()};
3594
3595 // Update the loop infos.
3596 loopInfos.clear();
3597 for (const auto &newLoop : newLoops)
3598 loopInfos.push_back(newLoop);
3599 } // Tiling done.
3600
3601 // Do collapse.
3602 const auto &numCollapse = loopOp.getCollapseNumLoops();
3604 loopInfos.begin(), loopInfos.begin() + (numCollapse));
3605
3606 auto newTopLoopInfo =
3607 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
3608
3609 assert(newTopLoopInfo && "New top loop information is missing");
3610 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
3611 [&](OpenMPLoopInfoStackFrame &frame) {
3612 frame.loopInfo = newTopLoopInfo;
3613 return WalkResult::interrupt();
3614 });
3615
3616 // Continue building IR after the loop. Note that the LoopInfo returned by
3617 // `collapseLoops` points inside the outermost loop and is intended for
3618 // potential further loop transformations. Use the insertion point stored
3619 // before collapsing loops instead.
3620 builder.restoreIP(afterIP);
3621 return success();
3622}
3623
3624/// Convert an omp.canonical_loop to LLVM-IR
3625static LogicalResult
3626convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
3627 LLVM::ModuleTranslation &moduleTranslation) {
3628 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3629
3630 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3631 Value loopIV = op.getInductionVar();
3632 Value loopTC = op.getTripCount();
3633
3634 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3635
3637 ompBuilder->createCanonicalLoop(
3638 loopLoc,
3639 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3640 // Register the mapping of MLIR induction variable to LLVM-IR
3641 // induction variable
3642 moduleTranslation.mapValue(loopIV, llvmIV);
3643
3644 builder.restoreIP(ip);
3646 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
3647 moduleTranslation);
3648
3649 return bodyGenStatus.takeError();
3650 },
3651 llvmTC, "omp.loop");
3652 if (!llvmOrError)
3653 return op.emitError(llvm::toString(llvmOrError.takeError()));
3654
3655 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3656 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3657 builder.restoreIP(afterIP);
3658
3659 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
3660 if (Value cli = op.getCli())
3661 moduleTranslation.mapOmpLoop(cli, llvmCLI);
3662
3663 return success();
3664}
3665
3666/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3667/// OpenMPIRBuilder.
3668static LogicalResult
3669applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3670 LLVM::ModuleTranslation &moduleTranslation) {
3671 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3672
3673 Value applyee = op.getApplyee();
3674 assert(applyee && "Loop to apply unrolling on required");
3675
3676 llvm::CanonicalLoopInfo *consBuilderCLI =
3677 moduleTranslation.lookupOMPLoop(applyee);
3678 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3679 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3680
3681 moduleTranslation.invalidateOmpLoop(applyee);
3682 return success();
3683}
3684
3685/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
3686/// OpenMPIRBuilder.
3687static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
3688 LLVM::ModuleTranslation &moduleTranslation) {
3689 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3690 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3691
3693 SmallVector<llvm::Value *> translatedSizes;
3694
3695 for (Value size : op.getSizes()) {
3696 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
3697 assert(translatedSize &&
3698 "sizes clause arguments must already be translated");
3699 translatedSizes.push_back(translatedSize);
3700 }
3701
3702 for (Value applyee : op.getApplyees()) {
3703 llvm::CanonicalLoopInfo *consBuilderCLI =
3704 moduleTranslation.lookupOMPLoop(applyee);
3705 assert(applyee && "Canonical loop must already been translated");
3706 translatedLoops.push_back(consBuilderCLI);
3707 }
3708
3709 auto generatedLoops =
3710 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
3711 if (!op.getGeneratees().empty()) {
3712 for (auto [mlirLoop, genLoop] :
3713 zip_equal(op.getGeneratees(), generatedLoops))
3714 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
3715 }
3716
3717 // CLIs can only be consumed once
3718 for (Value applyee : op.getApplyees())
3719 moduleTranslation.invalidateOmpLoop(applyee);
3720
3721 return success();
3722}
3723
3724/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3725static llvm::AtomicOrdering
3726convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3727 if (!ao)
3728 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3729
3730 switch (*ao) {
3731 case omp::ClauseMemoryOrderKind::Seq_cst:
3732 return llvm::AtomicOrdering::SequentiallyConsistent;
3733 case omp::ClauseMemoryOrderKind::Acq_rel:
3734 return llvm::AtomicOrdering::AcquireRelease;
3735 case omp::ClauseMemoryOrderKind::Acquire:
3736 return llvm::AtomicOrdering::Acquire;
3737 case omp::ClauseMemoryOrderKind::Release:
3738 return llvm::AtomicOrdering::Release;
3739 case omp::ClauseMemoryOrderKind::Relaxed:
3740 return llvm::AtomicOrdering::Monotonic;
3741 }
3742 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3743}
3744
3745/// Convert omp.atomic.read operation to LLVM IR.
3746static LogicalResult
3747convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3748 LLVM::ModuleTranslation &moduleTranslation) {
3749 auto readOp = cast<omp::AtomicReadOp>(opInst);
3750 if (failed(checkImplementationStatus(opInst)))
3751 return failure();
3752
3753 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3754 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3755 findAllocaInsertPoint(builder, moduleTranslation);
3756
3757 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3758
3759 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3760 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
3761 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
3762
3763 llvm::Type *elementType =
3764 moduleTranslation.convertType(readOp.getElementType());
3765
3766 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
3767 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
3768 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
3769 return success();
3770}
3771
3772/// Converts an omp.atomic.write operation to LLVM IR.
3773static LogicalResult
3774convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3775 LLVM::ModuleTranslation &moduleTranslation) {
3776 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3777 if (failed(checkImplementationStatus(opInst)))
3778 return failure();
3779
3780 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3781 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3782 findAllocaInsertPoint(builder, moduleTranslation);
3783
3784 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3785 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3786 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
3787 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
3788 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
3789 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
3790 /*isVolatile=*/false};
3791 builder.restoreIP(
3792 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
3793 return success();
3794}
3795
3796/// Converts an LLVM dialect binary operation to the corresponding enum value
3797/// for `atomicrmw` supported binary operation.
3798static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3800 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3801 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3802 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3803 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3804 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3805 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3806 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3807 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3808 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3809 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3810}
3811
3812static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
3813 bool &isIgnoreDenormalMode,
3814 bool &isFineGrainedMemory,
3815 bool &isRemoteMemory) {
3816 isIgnoreDenormalMode = false;
3817 isFineGrainedMemory = false;
3818 isRemoteMemory = false;
3819 if (atomicUpdateOp &&
3820 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
3821 mlir::omp::AtomicControlAttr atomicControlAttr =
3822 atomicUpdateOp.getAtomicControlAttr();
3823 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
3824 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
3825 isRemoteMemory = atomicControlAttr.getRemoteMemory();
3826 }
3827}
3828
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// If the update region consists of exactly one binary operation (plus the
/// terminator) that uses the region argument, the binop is passed down so
/// OpenMPIRBuilder can try to emit an `atomicrmw`; otherwise BAD_BINOP is
/// passed and the builder falls back to a cmpxchg loop driven by `updateFn`.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Record whether `x` is the LHS of the binop; the IR builder needs this
    // to know the operand order of non-commutative updates (e.g. sub).
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    // The "expression" operand is whichever operand is not the region
    // argument (i.e. not `x` itself).
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code.
  // Callback used by OpenMPIRBuilder (in the cmpxchg-loop path) to produce
  // the updated value: it inlines the MLIR update region with the region
  // argument bound to `atomicx` and returns the value yielded by the
  // region's omp.yield terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
                            isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr, isIgnoreDenormalMode,
                                     isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
3913
/// Converts an omp.atomic.capture operation using OpenMPIRBuilder.
///
/// The capture construct contains an atomic read paired with either an
/// atomic write or an atomic update; this routine determines which form is
/// present, whether the captured value is taken before or after the update
/// (postfix vs. prefix), and whether the update can map to an `atomicrmw`
/// binop, before delegating to OpenMPIRBuilder::createAtomicCapture.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // read-then-write form: the old value is always captured first.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // With an update op, the capture is "postfix" when the update comes
    // second (i.e. the read captured the pre-update value).
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      // Exactly one op plus terminator: candidate for an atomicrmw binop.
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      // Whether `x` is the LHS of the binop (matters for non-commutative
      // operations such as sub).
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one op in the region: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback producing the new value for `x`: for the write form it is just
  // the stored expression; for the update form it inlines the update region
  // with the region argument bound to `atomicx` and returns the value
  // yielded by the region's omp.yield terminator.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  bool isIgnoreDenormalMode;
  bool isFineGrainedMemory;
  bool isRemoteMemory;
  extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
                            isFineGrainedMemory, isRemoteMemory);
  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
          isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
4014
4015static llvm::omp::Directive convertCancellationConstructType(
4016 omp::ClauseCancellationConstructType directive) {
4017 switch (directive) {
4018 case omp::ClauseCancellationConstructType::Loop:
4019 return llvm::omp::Directive::OMPD_for;
4020 case omp::ClauseCancellationConstructType::Parallel:
4021 return llvm::omp::Directive::OMPD_parallel;
4022 case omp::ClauseCancellationConstructType::Sections:
4023 return llvm::omp::Directive::OMPD_sections;
4024 case omp::ClauseCancellationConstructType::Taskgroup:
4025 return llvm::omp::Directive::OMPD_taskgroup;
4026 }
4027 llvm_unreachable("Unhandled cancellation construct type");
4028}
4029
4030static LogicalResult
4031convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4032 LLVM::ModuleTranslation &moduleTranslation) {
4033 if (failed(checkImplementationStatus(*op.getOperation())))
4034 return failure();
4035
4036 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4037 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4038
4039 llvm::Value *ifCond = nullptr;
4040 if (Value ifVar = op.getIfExpr())
4041 ifCond = moduleTranslation.lookupValue(ifVar);
4042
4043 llvm::omp::Directive cancelledDirective =
4044 convertCancellationConstructType(op.getCancelDirective());
4045
4046 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4047 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4048
4049 if (failed(handleError(afterIP, *op.getOperation())))
4050 return failure();
4051
4052 builder.restoreIP(afterIP.get());
4053
4054 return success();
4055}
4056
4057static LogicalResult
4058convertOmpCancellationPoint(omp::CancellationPointOp op,
4059 llvm::IRBuilderBase &builder,
4060 LLVM::ModuleTranslation &moduleTranslation) {
4061 if (failed(checkImplementationStatus(*op.getOperation())))
4062 return failure();
4063
4064 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4065 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4066
4067 llvm::omp::Directive cancelledDirective =
4068 convertCancellationConstructType(op.getCancelDirective());
4069
4070 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4071 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4072
4073 if (failed(handleError(afterIP, *op.getOperation())))
4074 return failure();
4075
4076 builder.restoreIP(afterIP.get());
4077
4078 return success();
4079}
4080
4081/// Converts an OpenMP Threadprivate operation into LLVM IR using
4082/// OpenMPIRBuilder.
4083static LogicalResult
4084convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4085 LLVM::ModuleTranslation &moduleTranslation) {
4086 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4087 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4088 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4089
4090 if (failed(checkImplementationStatus(opInst)))
4091 return failure();
4092
4093 Value symAddr = threadprivateOp.getSymAddr();
4094 auto *symOp = symAddr.getDefiningOp();
4095
4096 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4097 symOp = asCast.getOperand().getDefiningOp();
4098
4099 if (!isa<LLVM::AddressOfOp>(symOp))
4100 return opInst.emitError("Addressing symbol not found");
4101 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4102
4103 LLVM::GlobalOp global =
4104 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4105 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4106 llvm::Type *type = globalValue->getValueType();
4107 llvm::TypeSize typeSize =
4108 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4109 type);
4110 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4111 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4112 ompLoc, globalValue, size, global.getSymName() + ".cache");
4113 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4114
4115 return success();
4116}
4117
4118static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4119convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4120 switch (deviceClause) {
4121 case mlir::omp::DeclareTargetDeviceType::host:
4122 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4123 break;
4124 case mlir::omp::DeclareTargetDeviceType::nohost:
4125 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4126 break;
4127 case mlir::omp::DeclareTargetDeviceType::any:
4128 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4129 break;
4130 }
4131 llvm_unreachable("unhandled device clause");
4132}
4133
4134static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4136 mlir::omp::DeclareTargetCaptureClause captureClause) {
4137 switch (captureClause) {
4138 case mlir::omp::DeclareTargetCaptureClause::to:
4139 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4140 case mlir::omp::DeclareTargetCaptureClause::link:
4141 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4142 case mlir::omp::DeclareTargetCaptureClause::enter:
4143 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4144 case mlir::omp::DeclareTargetCaptureClause::none:
4145 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4146 }
4147 llvm_unreachable("unhandled capture clause");
4148}
4149
4151 Operation *op = value.getDefiningOp();
4152 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4153 op = addrCast->getOperand(0).getDefiningOp();
4154 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4155 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4156 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4157 }
4158 return nullptr;
4159}
4160
/// Builds the suffix appended to a declare-target global's name to form its
/// reference-pointer name ("..._decl_tgt_ref_ptr"). Private globals
/// additionally get a file-unique hex id folded in so identically named
/// private symbols from different files do not collide.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): assumes the global's location contains a FileLineColLoc;
    // if findInstanceOf returns null, the callback below dereferences a null
    // location — confirm callers guarantee a file location here.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    // The unique FileID is derived from the real file system entry for the
    // source file, matching how target entry info is computed elsewhere.
    auto vfs = llvm::vfs::getRealFileSystem();
    os << llvm::format(
        "_%x",
        ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
4182
4183static bool isDeclareTargetLink(Value value) {
4184 if (auto declareTargetGlobal =
4185 dyn_cast_if_present<omp::DeclareTargetInterface>(
4186 getGlobalOpFromValue(value)))
4187 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4188 omp::DeclareTargetCaptureClause::link)
4189 return true;
4190 return false;
4191}
4192
4193static bool isDeclareTargetTo(Value value) {
4194 if (auto declareTargetGlobal =
4195 dyn_cast_if_present<omp::DeclareTargetInterface>(
4196 getGlobalOpFromValue(value)))
4197 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4198 omp::DeclareTargetCaptureClause::to ||
4199 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4200 omp::DeclareTargetCaptureClause::enter)
4201 return true;
4202 return false;
4203}
4204
4205// Returns the reference pointer generated by the lowering of the declare
4206// target operation in cases where the link clause is used or the to clause is
4207// used in USM mode.
4208static llvm::Value *
4210 LLVM::ModuleTranslation &moduleTranslation) {
4211 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4212 if (auto gOp =
4213 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
4214 if (auto declareTargetGlobal =
4215 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4216 // In this case, we must utilise the reference pointer generated by
4217 // the declare target operation, similar to Clang
4218 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4219 omp::DeclareTargetCaptureClause::link) ||
4220 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4221 omp::DeclareTargetCaptureClause::to &&
4222 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4223 llvm::SmallString<64> suffix =
4224 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
4225
4226 if (gOp.getSymName().contains(suffix))
4227 return moduleTranslation.getLLVMModule()->getNamedValue(
4228 gOp.getSymName());
4229
4230 return moduleTranslation.getLLVMModule()->getNamedValue(
4231 (gOp.getSymName().str() + suffix.str()).str());
4232 }
4233 }
4234 }
4235 return nullptr;
4236}
4237
4238namespace {
4239// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // Per-entry custom mapper operation (omp.declare_mapper), or nullptr when
  // the entry has no custom mapper; runs parallel to the base class arrays.
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
  }
};
4249// A small helper structure to contain data gathered
4250// for map lowering and coalese it into one area and
4251// avoiding extra computations such as searches in the
4252// llvm module for lowered mapped variables or checking
4253// if something is declare target (and retrieving the
4254// value) more than neccessary.
struct MapInfoData : MapInfosTy {
  // True when the mapped value is a declare-target global (its base pointer
  // is then the generated reference pointer rather than the value itself).
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True when this entry is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // Originating map-clause operation for each entry.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // LLVM value the MLIR operand translated to, before any adjustment.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  /// NOTE(review): IsAMember and IsAMapping are not appended here — confirm
  /// callers never rely on those fields after an append.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    MapInfosTy::append(CurInfo);
  }
};
4277
// Discriminates which target-related OpenMP construct an operation
// represents; None marks operations outside this family (see
// getTargetDirectiveEnumTyFromOp).
enum class TargetDirectiveEnumTy : uint32_t {
  None = 0,
  Target = 1,
  TargetData = 2,
  TargetEnterData = 3,
  TargetExitData = 4,
  TargetUpdate = 5
};
4286
4287static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4288 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4289 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4290 .Case([](omp::TargetEnterDataOp) {
4291 return TargetDirectiveEnumTy::TargetEnterData;
4292 })
4293 .Case([&](omp::TargetExitDataOp) {
4294 return TargetDirectiveEnumTy::TargetExitData;
4295 })
4296 .Case([&](omp::TargetUpdateOp) {
4297 return TargetDirectiveEnumTy::TargetUpdate;
4298 })
4299 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4300 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4301}
4302
4303} // namespace
4304
4305static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4306 DataLayout &dl) {
4307 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4308 arrTy.getElementType()))
4309 return getArrayElementSizeInBits(nestedArrTy, dl);
4310 return dl.getTypeSizeInBits(arrTy.getElementType());
4311}
4312
4313// This function calculates the size to be offloaded for a specified type, given
4314// its associated map clause (which can contain bounds information which affects
4315// the total size), this size is calculated based on the underlying element type
4316// e.g. given a 1-D array of ints, we will calculate the size from the integer
4317// type * number of elements in the array. This size can be used in other
4318// calculations but is ultimately used as an argument to the OpenMP runtimes
4319// kernel argument structure which is generated through the combinedInfo data
4320// structures.
4321// This function is somewhat equivalent to Clang's getExprTypeSize inside of
4322// CGOpenMPRuntime.cpp.
static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
                                   Operation *clauseOp,
                                   llvm::Value *basePointer,
                                   llvm::Type *baseType,
                                   llvm::IRBuilderBase &builder,
                                   LLVM::ModuleTranslation &moduleTranslation) {
  // Note: basePointer and baseType are currently unused in this body; they
  // are kept for interface compatibility with callers.
  if (auto memberClause =
          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
    // This calculates the size to transfer based on bounds and the underlying
    // element type, provided bounds have been specified (Fortran
    // pointers/allocatables/target and arrays that have sections specified fall
    // into this as well)
    if (!memberClause.getBounds().empty()) {
      llvm::Value *elementCount = builder.getInt64(1);
      for (auto bounds : memberClause.getBounds()) {
        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
                bounds.getDefiningOp())) {
          // The below calculation for the size to be mapped calculated from the
          // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
          // multiply by the underlying element types byte size to get the full
          // size to be offloaded based on the bounds
          elementCount = builder.CreateMul(
              elementCount,
              builder.CreateAdd(
                  builder.CreateSub(
                      moduleTranslation.lookupValue(boundOp.getUpperBound()),
                      moduleTranslation.lookupValue(boundOp.getLowerBound())),
                  builder.getInt64(1)));
        }
      }

      // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
      // the size in inconsistent byte or bit format.
      uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
      // For arrays, size by the innermost element type — the bounds already
      // account for the element count.
      if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
        underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);

      // The size in bytes x number of elements, the sizeInBytes stored is
      // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
      // size, so we do some on the fly runtime math to get the size in
      // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
      // some adjustment for members with more complex types.
      return builder.CreateMul(elementCount,
                               builder.getInt64(underlyingTypeSzInBits / 8));
    }
  }

  // No bounds information: fall back to the full static size of the type.
  return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
4372
4373// Convert the MLIR map flag set to the runtime map flag set for embedding
4374// in LLVM-IR. This is important as the two bit-flag lists do not correspond
4375// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
4376// Certain flags are discarded here such as RefPtee and co.
4377static llvm::omp::OpenMPOffloadMappingFlags
4378convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
4379 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
4380 return (mlirFlags & flag) == flag;
4381 };
4382 const bool hasExplicitMap =
4383 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4384 omp::ClauseMapFlags::none;
4385
4386 llvm::omp::OpenMPOffloadMappingFlags mapType =
4387 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
4388
4389 if (mapTypeToBool(omp::ClauseMapFlags::to))
4390 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
4391
4392 if (mapTypeToBool(omp::ClauseMapFlags::from))
4393 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
4394
4395 if (mapTypeToBool(omp::ClauseMapFlags::always))
4396 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4397
4398 if (mapTypeToBool(omp::ClauseMapFlags::del))
4399 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
4400
4401 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
4402 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4403
4404 if (mapTypeToBool(omp::ClauseMapFlags::priv))
4405 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
4406
4407 if (mapTypeToBool(omp::ClauseMapFlags::literal))
4408 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4409
4410 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
4411 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
4412
4413 if (mapTypeToBool(omp::ClauseMapFlags::close))
4414 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
4415
4416 if (mapTypeToBool(omp::ClauseMapFlags::present))
4417 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
4418
4419 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
4420 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
4421
4422 if (mapTypeToBool(omp::ClauseMapFlags::attach))
4423 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
4424
4425 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4426 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4427 if (!hasExplicitMap)
4428 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4429 }
4430
4431 return mapType;
4432}
4433
4435 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
4436 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
4437 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
4438 ArrayRef<Value> useDevAddrOperands = {},
4439 ArrayRef<Value> hasDevAddrOperands = {}) {
4440 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
4441 // Check if this is a member mapping and correctly assign that it is, if
4442 // it is a member of a larger object.
4443 // TODO: Need better handling of members, and distinguishing of members
4444 // that are implicitly allocated on device vs explicitly passed in as
4445 // arguments.
4446 // TODO: May require some further additions to support nested record
4447 // types, i.e. member maps that can have member maps.
4448 for (Value mapValue : mapVars) {
4449 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4450 for (auto member : map.getMembers())
4451 if (member == mapOp)
4452 return true;
4453 }
4454 return false;
4455 };
4456
4457 // Process MapOperands
4458 for (Value mapValue : mapVars) {
4459 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4460 Value offloadPtr =
4461 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4462 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
4463 mapData.Pointers.push_back(mapData.OriginalValue.back());
4464
4465 if (llvm::Value *refPtr =
4466 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
4467 mapData.IsDeclareTarget.push_back(true);
4468 mapData.BasePointers.push_back(refPtr);
4469 } else if (isDeclareTargetTo(offloadPtr)) {
4470 mapData.IsDeclareTarget.push_back(true);
4471 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4472 } else { // regular mapped variable
4473 mapData.IsDeclareTarget.push_back(false);
4474 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4475 }
4476
4477 mapData.BaseType.push_back(
4478 moduleTranslation.convertType(mapOp.getVarType()));
4479 mapData.Sizes.push_back(
4480 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
4481 mapData.BaseType.back(), builder, moduleTranslation));
4482 mapData.MapClause.push_back(mapOp.getOperation());
4483 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
4484 mapData.Names.push_back(LLVM::createMappingInformation(
4485 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4486 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4487 if (mapOp.getMapperId())
4488 mapData.Mappers.push_back(
4490 mapOp, mapOp.getMapperIdAttr()));
4491 else
4492 mapData.Mappers.push_back(nullptr);
4493 mapData.IsAMapping.push_back(true);
4494 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
4495 }
4496
4497 auto findMapInfo = [&mapData](llvm::Value *val,
4498 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4499 unsigned index = 0;
4500 bool found = false;
4501 for (llvm::Value *basePtr : mapData.OriginalValue) {
4502 if (basePtr == val && mapData.IsAMapping[index]) {
4503 found = true;
4504 mapData.Types[index] |=
4505 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
4506 mapData.DevicePointers[index] = devInfoTy;
4507 }
4508 index++;
4509 }
4510 return found;
4511 };
4512
4513 // Process useDevPtr(Addr)Operands
4514 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
4515 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
4516 for (Value mapValue : useDevOperands) {
4517 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4518 Value offloadPtr =
4519 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4520 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4521
4522 // Check if map info is already present for this entry.
4523 if (!findMapInfo(origValue, devInfoTy)) {
4524 mapData.OriginalValue.push_back(origValue);
4525 mapData.Pointers.push_back(mapData.OriginalValue.back());
4526 mapData.IsDeclareTarget.push_back(false);
4527 mapData.BasePointers.push_back(mapData.OriginalValue.back());
4528 mapData.BaseType.push_back(
4529 moduleTranslation.convertType(mapOp.getVarType()));
4530 mapData.Sizes.push_back(builder.getInt64(0));
4531 mapData.MapClause.push_back(mapOp.getOperation());
4532 mapData.Types.push_back(
4533 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
4534 mapData.Names.push_back(LLVM::createMappingInformation(
4535 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4536 mapData.DevicePointers.push_back(devInfoTy);
4537 mapData.Mappers.push_back(nullptr);
4538 mapData.IsAMapping.push_back(false);
4539 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
4540 }
4541 }
4542 };
4543
4544 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4545 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
4546
4547 for (Value mapValue : hasDevAddrOperands) {
4548 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
4549 Value offloadPtr =
4550 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
4551 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
4552 auto mapType = convertClauseMapFlags(mapOp.getMapType());
4553 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4554 bool isDevicePtr =
4555 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4556 omp::ClauseMapFlags::none;
4557
4558 mapData.OriginalValue.push_back(origValue);
4559 mapData.BasePointers.push_back(origValue);
4560 mapData.Pointers.push_back(origValue);
4561 mapData.IsDeclareTarget.push_back(false);
4562 mapData.BaseType.push_back(
4563 moduleTranslation.convertType(mapOp.getVarType()));
4564 mapData.Sizes.push_back(
4565 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
4566 mapData.MapClause.push_back(mapOp.getOperation());
4567 if (llvm::to_underlying(mapType & mapTypeAlways)) {
4568 // Descriptors are mapped with the ALWAYS flag, since they can get
4569 // rematerialized, so the address of the decriptor for a given object
4570 // may change from one place to another.
4571 mapData.Types.push_back(mapType);
4572 // Technically it's possible for a non-descriptor mapping to have
4573 // both has-device-addr and ALWAYS, so lookup the mapper in case it
4574 // exists.
4575 if (mapOp.getMapperId()) {
4576 mapData.Mappers.push_back(
4578 mapOp, mapOp.getMapperIdAttr()));
4579 } else {
4580 mapData.Mappers.push_back(nullptr);
4581 }
4582 } else {
4583 // For is_device_ptr we need the map type to propagate so the runtime
4584 // can materialize the device-side copy of the pointer container.
4585 mapData.Types.push_back(
4586 isDevicePtr ? mapType
4587 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4588 mapData.Mappers.push_back(nullptr);
4589 }
4590 mapData.Names.push_back(LLVM::createMappingInformation(
4591 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
4592 mapData.DevicePointers.push_back(
4593 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4594 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4595 mapData.IsAMapping.push_back(false);
4596 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
4597 }
4598}
4599
4600static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
4601 auto *res = llvm::find(mapData.MapClause, memberOp);
4602 assert(res != mapData.MapClause.end() &&
4603 "MapInfoOp for member not found in MapData, cannot return index");
4604 return std::distance(mapData.MapClause.begin(), res);
4605}
4606
4608 omp::MapInfoOp mapInfo) {
4609 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4610 llvm::SmallVector<size_t> occludedChildren;
4611 llvm::sort(
4612 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
4613 // Bail early if we are asked to look at the same index. If we do not
4614 // bail early, we can end up mistakenly adding indices to
4615 // occludedChildren. This can occur with some types of libc++ hardening.
4616 if (a == b)
4617 return false;
4618
4619 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
4620 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
4621
4622 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
4623 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
4624 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
4625
4626 if (aIndex == bIndex)
4627 continue;
4628
4629 if (aIndex < bIndex)
4630 return true;
4631
4632 if (aIndex > bIndex)
4633 return false;
4634 }
4635
4636 // Iterated up until the end of the smallest member and
4637 // they were found to be equal up to that point, so select
4638 // the member with the lowest index count, so the "parent"
4639 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
4640 if (memberAParent)
4641 occludedChildren.push_back(b);
4642 else
4643 occludedChildren.push_back(a);
4644 return memberAParent;
4645 });
4646
4647 // We remove children from the index list that are overshadowed by
4648 // a parent, this prevents us retrieving these as the first or last
4649 // element when the parent is the correct element in these cases.
4650 for (auto v : occludedChildren)
4651 indices.erase(std::remove(indices.begin(), indices.end(), v),
4652 indices.end());
4653}
4654
4655static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
4656 bool first) {
4657 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
4658 // Only 1 member has been mapped, we can return it.
4659 if (indexAttr.size() == 1)
4660 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
4661 llvm::SmallVector<size_t> indices(indexAttr.size());
4662 std::iota(indices.begin(), indices.end(), 0);
4663 sortMapIndices(indices, mapInfo);
4664 return llvm::cast<omp::MapInfoOp>(
4665 mapInfo.getMembers()[first ? indices.front() : indices.back()]
4666 .getDefiningOp());
4667}
4668
4669/// This function calculates the array/pointer offset for map data provided
4670/// with bounds operations, e.g. when provided something like the following:
4671///
4672/// Fortran
4673/// map(tofrom: array(2:5, 3:2))
4674///
4675/// We must calculate the initial pointer offset to pass across, this function
4676/// performs this using bounds.
4677///
4678/// TODO/WARNING: This only supports Fortran's column major indexing currently
4679/// as is noted in the note below and comments in the function, we must extend
4680/// this function when we add a C++ frontend.
4681/// NOTE: which while specified in row-major order it currently needs to be
4682/// flipped for Fortran's column order array allocation and access (as
4683/// opposed to C++'s row-major, hence the backwards processing where order is
4684/// important). This is likely important to keep in mind for the future when
4685/// we incorporate a C++ frontend, both frontends will need to agree on the
4686/// ordering of generated bounds operations (one may have to flip them) to
4687/// make the below lowering frontend agnostic. The offload size
4688/// calcualtion may also have to be adjusted for C++.
4689static std::vector<llvm::Value *>
4691 llvm::IRBuilderBase &builder, bool isArrayTy,
4692 OperandRange bounds) {
4693 std::vector<llvm::Value *> idx;
4694 // There's no bounds to calculate an offset from, we can safely
4695 // ignore and return no indices.
4696 if (bounds.empty())
4697 return idx;
4698
4699 // If we have an array type, then we have its type so can treat it as a
4700 // normal GEP instruction where the bounds operations are simply indexes
4701 // into the array. We currently do reverse order of the bounds, which
4702 // I believe leans more towards Fortran's column-major in memory.
4703 if (isArrayTy) {
4704 idx.push_back(builder.getInt64(0));
4705 for (int i = bounds.size() - 1; i >= 0; --i) {
4706 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4707 bounds[i].getDefiningOp())) {
4708 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
4709 }
4710 }
4711 } else {
4712 // If we do not have an array type, but we have bounds, then we're dealing
4713 // with a pointer that's being treated like an array and we have the
4714 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
4715 // address (pointer pointing to the actual data) so we must caclulate the
4716 // offset using a single index which the following loop attempts to
4717 // compute using the standard column-major algorithm e.g for a 3D array:
4718 //
4719 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
4720 //
4721 // It is of note that it's doing column-major rather than row-major at the
4722 // moment, but having a way for the frontend to indicate which major format
4723 // to use or standardizing/canonicalizing the order of the bounds to compute
4724 // the offset may be useful in the future when there's other frontends with
4725 // different formats.
4726 std::vector<llvm::Value *> dimensionIndexSizeOffset;
4727 for (int i = bounds.size() - 1; i >= 0; --i) {
4728 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
4729 bounds[i].getDefiningOp())) {
4730 if (i == ((int)bounds.size() - 1))
4731 idx.emplace_back(
4732 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4733 else
4734 idx.back() = builder.CreateAdd(
4735 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
4736 boundOp.getExtent())),
4737 moduleTranslation.lookupValue(boundOp.getLowerBound()));
4738 }
4739 }
4740 }
4741
4742 return idx;
4743}
4744
4746 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
4747 return cast<IntegerAttr>(value).getInt();
4748 });
4749}
4750
// Gathers members that are overlapping in the parent, excluding members that
// themselves overlap, keeping the top-most (closest to parents level) map.
//
// The surviving member positions (indices into the parent's member list) are
// appended to the output index vector. NOTE(review): the line declaring this
// function's name and first parameter is missing from this view — confirm
// against the full source.
static void
                     omp::MapInfoOp parentOp) {
  // No members mapped, no overlaps.
  if (parentOp.getMembers().empty())
    return;

  // Single member, we can insert and return early.
  if (parentOp.getMembers().size() == 1) {
    overlapMapDataIdxs.push_back(0);
    return;
  }

  // 1) collect list of top-level overlapping members from MemberOp as pairs
  // of <position in member list, member index path>.
  // NOTE(review): the declaration of `memberByIndex` is on a line missing
  // from this view.
  ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
  for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
    memberByIndex.push_back(
        std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));

  // Sort the smallest first (higher up the parent -> member chain), so that
  // when we remove members, we remove as much as we can in the initial
  // iterations, shortening the number of passes required.
  llvm::sort(memberByIndex.begin(), memberByIndex.end(),
             [&](auto a, auto b) { return a.second.size() < b.second.size(); });

  // Remove elements from the vector if there is a parent element that
  // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
  // [0,2].. etc. A member x is superseded by v when v's index path is a
  // prefix of x's path. NOTE(review): the declaration of `skipList` is on a
  // line missing from this view; also, if no superseded member exists for a
  // given `v`, find_if returns end() and dereferencing it is UB — confirm
  // whether the inputs guarantee a match.
  for (auto v : memberByIndex) {
    llvm::SmallVector<int64_t> vArr(v.second.size());
    getAsIntegers(v.second, vArr);
    skipList.push_back(
        *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
          if (v == x)
            return false;
          llvm::SmallVector<int64_t> xArr(x.second.size());
          getAsIntegers(x.second, xArr);
          return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
                 xArr.size() >= vArr.size();
        }));
  }

  // Collect the indices, as we need the base pointer etc. from the MapData
  // structure which is primarily accessible via index at the moment.
  for (auto v : memberByIndex)
    if (find(skipList.begin(), skipList.end(), v) == skipList.end())
      overlapMapDataIdxs.push_back(v.first);
}
4803
4804// The intent is to verify if the mapped data being passed is a
4805// pointer -> pointee that requires special handling in certain cases,
4806// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4807//
4808// There may be a better way to verify this, but unfortunately with
4809// opaque pointers we lose the ability to easily check if something is
4810// a pointer whilst maintaining access to the underlying type.
4811static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4812 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4813 if (mapOp.getVarPtrPtr())
4814 return true;
4815
4816 // If the map data is declare target with a link clause, then it's represented
4817 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4818 // no relation to pointers.
4819 if (isDeclareTargetLink(mapOp.getVarPtr()))
4820 return true;
4821
4822 return false;
4823}
4824
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
// explicitly mapped on the same map clause. Certain types, such as Fortran
// descriptors are mapped like this as well, however, the members are
// implicit as far as a user is concerned, but we must explicitly map them
// internally.
//
// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map
// type with it) to indicate that a member is part of this parent and should
// be treated by the runtime as such. Important to achieve the correct
// mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  auto *parentMapper = mapData.Mappers[mapDataIndex];

  // Map the first segment of the parent. If a user-defined mapper is attached,
  // include the parent's to/from-style bits (and common modifiers) in this
  // base entry so the mapper receives correct copy semantics via its 'type'
  // parameter. Also keep TARGET_PARAM when required for kernel arguments.
  llvm::omp::OpenMPOffloadMappingFlags baseFlag =
      (targetDirective == TargetDirectiveEnumTy::Target &&
       !mapData.IsDeclareTarget[mapDataIndex])
          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

  if (parentMapper) {
    using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
    // Preserve relevant map-type bits from the parent clause. These include
    // the copy direction (TO/FROM), as well as commonly used modifiers that
    // should be visible to the mapper for correct behaviour.
    mapFlags parentFlags = mapData.Types[mapDataIndex];
    mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
                        mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
                        mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
    baseFlag |= (parentFlags & preserve);
  }

  combinedInfo.Types.emplace_back(baseFlag);
  combinedInfo.DevicePointers.emplace_back(
      mapData.DevicePointers[mapDataIndex]);
  // Only attach the mapper to the base entry when we are mapping the whole
  // parent. Combined/segment entries must not carry a mapper; otherwise the
  // mapper can be invoked with a partial size, which is undefined behaviour.
  combinedInfo.Mappers.emplace_back(
      parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Whole-object map: the segment spans the parent object itself,
    // [parent, parent + 1) in units of the parent's base type.
    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                        builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                   mapData.Pointers[mapDataIndex], 1),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: only members were mapped, so the segment spans from the
    // first mapped member to one past the last mapped member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
                                        builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
    highAddr = builder.CreatePointerCast(
        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
        builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
      builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(size);

  // MEMBER_OF encodes the position of the parent's base entry (pushed just
  // above) so that subsequent member entries can reference it.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

    if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
      // For target update, or when the 'close' modifier is present, emit a
      // single full-sized entry for the parent instead of splitting it into
      // segments around overlapped members.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
      combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
    } else {
      llvm::SmallVector<size_t> overlapIdxs;
      // Find all of the members that "overlap", i.e. occlude other members
      // that were mapped alongside the parent, e.g. member [0], occludes
      // [0,1] and [0,2], but not [1,0].
      getOverlappedMembers(overlapIdxs, parentClause);
      // We need to make sure the overlapped members are sorted in order of
      // lowest address to highest address.
      sortMapIndices(overlapIdxs, parentClause);

      lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
                                          builder.getPtrTy());
      highAddr = builder.CreatePointerCast(
          builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
                                     mapData.Pointers[mapDataIndex], 1),
          builder.getPtrTy());

      // TODO: We may want to skip arrays/array sections in this as Clang
      // does. It appears to be an optimisation rather than a necessity
      // though, but this requires further investigation. However, we would
      // have to make sure to not exclude maps with bounds that ARE pointers,
      // as these are processed as separate components, i.e. pointer + data.
      for (auto v : overlapIdxs) {
        // Emit the gap [lowAddr, start of overlapped member) as its own
        // segment, then advance lowAddr to one element past the overlapped
        // member so the next iteration (or the final segment) resumes there.
        auto mapDataOverlapIdx = getMapDataMemberIdx(
            mapData,
            cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
        combinedInfo.Types.emplace_back(mapFlag);
        combinedInfo.DevicePointers.emplace_back(
            mapData.DevicePointers[mapDataOverlapIdx]);
        combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
            mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
        combinedInfo.BasePointers.emplace_back(
            mapData.BasePointers[mapDataIndex]);
        combinedInfo.Mappers.emplace_back(nullptr);
        combinedInfo.Pointers.emplace_back(lowAddr);
        combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
            builder.CreatePtrDiff(builder.getInt8Ty(),
                                  mapData.OriginalValue[mapDataOverlapIdx],
                                  lowAddr),
            builder.getInt64Ty(), /*isSigned=*/true));
        lowAddr = builder.CreateConstGEP1_32(
            checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
                mapData.MapClause[mapDataOverlapIdx]))
                ? builder.getPtrTy()
                : mapData.BaseType[mapDataOverlapIdx],
            mapData.BasePointers[mapDataOverlapIdx], 1);
      }

      // Final segment: from past the last overlapped member to the end of
      // the parent object.
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          mapData.DevicePointers[mapDataIndex]);
      combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
          mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Pointers.emplace_back(lowAddr);
      combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
          builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
          builder.getInt64Ty(), true));
    }
  }
  return memberOfFlag;
}
5019
// This function is intended to add explicit mappings of members
// belonging to the given parent map entry; each member entry is tagged with
// the parent's MEMBER_OF flag (via setCorrectMemberOfFlag) so the runtime
// associates it with the parent emitted by mapParentWithMembers.
// NOTE(review): the line carrying this function's name and opening
// parenthesis is missing from this view — confirm against the full source.
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex,
    llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  for (auto mappedMembers : parentClause.getMembers()) {
    auto memberClause =
        llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);

    assert(memberDataIdx >= 0 && "could not find mapped member of structure");

    // If we're currently mapping a pointer to a block of data, we must
    // initially map the pointer, and then attach/bind the data with a
    // subsequent map to the pointer. This segment of code generates the
    // pointer mapping, which can in certain cases be optimised out as Clang
    // currently does in its lowering. However, for the moment we do not do
    // so, in part as we currently have substantially less information on the
    // data being mapped at this stage.
    if (checkIfPointerMap(memberClause)) {
      auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
      mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
      ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
      combinedInfo.Types.emplace_back(mapFlag);
      combinedInfo.DevicePointers.emplace_back(
          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
      combinedInfo.Mappers.emplace_back(nullptr);
      combinedInfo.Names.emplace_back(
          LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
      combinedInfo.BasePointers.emplace_back(
          mapData.BasePointers[mapDataIndex]);
      combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
      // The pointer component itself is always pointer-sized.
      combinedInfo.Sizes.emplace_back(builder.getInt64(
          moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
    }

    // Apply the same MemberOfFlag to indicate the member's link with the
    // parent and the other members of the parent.
    auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
    mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
    ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
    bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
                                                ? parentClause.getVarPtr()
                                                : parentClause.getVarPtrPtr());
    // PTR_AND_OBJ is added for pointer members, except for declare-target-to
    // parents in target update/data constructs.
    if (checkIfPointerMap(memberClause) &&
        (!isDeclTargetTo ||
         (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
          targetDirective != TargetDirectiveEnumTy::TargetData))) {
      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
    }

    combinedInfo.Types.emplace_back(mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        mapData.DevicePointers[memberDataIdx]);
    combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
    combinedInfo.Names.emplace_back(
        LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
    // Pointer members use their own base pointer; non-pointer members are
    // based off the parent.
    uint64_t basePointerIndex =
        checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
    combinedInfo.BasePointers.emplace_back(
        mapData.BasePointers[basePointerIndex]);
    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);

    llvm::Value *size = mapData.Sizes[memberDataIdx];
    if (checkIfPointerMap(memberClause)) {
      // Guard against null pointees: map a zero-sized region when the member
      // pointer is null.
      size = builder.CreateSelect(
          builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
          builder.getInt64(0), size);
    }

    combinedInfo.Sizes.emplace_back(size);
  }
}
5103
5104static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5105 MapInfosTy &combinedInfo,
5106 TargetDirectiveEnumTy targetDirective,
5107 int mapDataParentIdx = -1) {
5108 // Declare Target Mappings are excluded from being marked as
5109 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5110 // marked with OMP_MAP_PTR_AND_OBJ instead.
5111 auto mapFlag = mapData.Types[mapDataIdx];
5112 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5113
5114 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5115 if (isPtrTy)
5116 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5117
5118 if (targetDirective == TargetDirectiveEnumTy::Target &&
5119 !mapData.IsDeclareTarget[mapDataIdx])
5120 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5121
5122 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5123 !isPtrTy)
5124 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5125
5126 // if we're provided a mapDataParentIdx, then the data being mapped is
5127 // part of a larger object (in a parent <-> member mapping) and in this
5128 // case our BasePointer should be the parent.
5129 if (mapDataParentIdx >= 0)
5130 combinedInfo.BasePointers.emplace_back(
5131 mapData.BasePointers[mapDataParentIdx]);
5132 else
5133 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5134
5135 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5136 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5137 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5138 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5139 combinedInfo.Types.emplace_back(mapFlag);
5140 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5141}
5142
// Processes a map entry that has explicitly mapped members: either emits the
// single member as an individual map (partial map with exactly one member),
// or emits the parent's combined entries followed by all member entries.
// NOTE(review): the line carrying this function's name and first parameter
// is missing from this view — confirm against the full source.
                                    llvm::IRBuilderBase &builder,
                                    llvm::OpenMPIRBuilder &ompBuilder,
                                    DataLayout &dl, MapInfosTy &combinedInfo,
                                    MapInfoData &mapData, uint64_t mapDataIndex,
                                    TargetDirectiveEnumTy targetDirective) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  auto parentClause =
      llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

  // If we have a partial map (no parent referenced in the map clauses of the
  // directive, only members) and only a single member, we do not need to bind
  // the map of the member to the parent, we can pass the member separately.
  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
    auto memberClause = llvm::cast<omp::MapInfoOp>(
        parentClause.getMembers()[0].getDefiningOp());
    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
    // Note: Clang treats arrays with explicit bounds that fall into this
    // category as a parent with map case, however, it seems this isn't a
    // requirement, and processing them as an individual map is fine. So,
    // we will handle them as individual maps for the moment, as it's
    // difficult for us to check this as we always require bounds to be
    // specified currently and it's also marginally more optimal (single
    // map rather than two). The difference may come from the fact that
    // Clang maps array without bounds as pointers (which we do not
    // currently do), whereas we treat them as arrays in all cases
    // currently.
    processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
                         mapDataIndex);
    return;
  }

  // Otherwise map the parent first, then tag each member entry with the
  // MEMBER_OF flag the parent mapping produced.
  llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
      mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                           combinedInfo, mapData, mapDataIndex,
                           targetDirective);
  processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
                              combinedInfo, mapData, mapDataIndex,
                              memberOfParentFlag, targetDirective);
}
5185
5186// This is a variation on Clang's GenerateOpenMPCapturedVars, which
5187// generates different operation (e.g. load/store) combinations for
5188// arguments to the kernel, based on map capture kinds which are then
5189// utilised in the combinedInfo in place of the original Map value.
5190static void
5191createAlteredByCaptureMap(MapInfoData &mapData,
5192 LLVM::ModuleTranslation &moduleTranslation,
5193 llvm::IRBuilderBase &builder) {
5194 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5195 "function only supported for host device codegen");
5196 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5197 // if it's declare target, skip it, it's handled separately.
5198 if (!mapData.IsDeclareTarget[i]) {
5199 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5200 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5201 bool isPtrTy = checkIfPointerMap(mapOp);
5202
5203 // Currently handles array sectioning lowerbound case, but more
5204 // logic may be required in the future. Clang invokes EmitLValue,
5205 // which has specialised logic for special Clang types such as user
5206 // defines, so it is possible we will have to extend this for
5207 // structures or other complex types. As the general idea is that this
5208 // function mimics some of the logic from Clang that we require for
5209 // kernel argument passing from host -> device.
5210 switch (captureKind) {
5211 case omp::VariableCaptureKind::ByRef: {
5212 llvm::Value *newV = mapData.Pointers[i];
5213 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
5214 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5215 mapOp.getBounds());
5216 if (isPtrTy)
5217 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5218
5219 if (!offsetIdx.empty())
5220 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5221 "array_offset");
5222 mapData.Pointers[i] = newV;
5223 } break;
5224 case omp::VariableCaptureKind::ByCopy: {
5225 llvm::Type *type = mapData.BaseType[i];
5226 llvm::Value *newV;
5227 if (mapData.Pointers[i]->getType()->isPointerTy())
5228 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5229 else
5230 newV = mapData.Pointers[i];
5231
5232 if (!isPtrTy) {
5233 auto curInsert = builder.saveIP();
5234 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5235 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
5236 auto *memTempAlloc =
5237 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
5238 builder.SetCurrentDebugLocation(DbgLoc);
5239 builder.restoreIP(curInsert);
5240
5241 builder.CreateStore(newV, memTempAlloc);
5242 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5243 }
5244
5245 mapData.Pointers[i] = newV;
5246 mapData.BasePointers[i] = newV;
5247 } break;
5248 case omp::VariableCaptureKind::This:
5249 case omp::VariableCaptureKind::VLAType:
5250 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
5251 break;
5252 }
5253 }
5254 }
5255}
5256
5257// Generate all map related information and fill the combinedInfo.
5258static void genMapInfos(llvm::IRBuilderBase &builder,
5259 LLVM::ModuleTranslation &moduleTranslation,
5260 DataLayout &dl, MapInfosTy &combinedInfo,
5261 MapInfoData &mapData,
5262 TargetDirectiveEnumTy targetDirective) {
5263 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5264 "function only supported for host device codegen");
5265 // We wish to modify some of the methods in which arguments are
5266 // passed based on their capture type by the target region, this can
5267 // involve generating new loads and stores, which changes the
5268 // MLIR value to LLVM value mapping, however, we only wish to do this
5269 // locally for the current function/target and also avoid altering
5270 // ModuleTranslation, so we remap the base pointer or pointer stored
5271 // in the map infos corresponding MapInfoData, which is later accessed
5272 // by genMapInfos and createTarget to help generate the kernel and
5273 // kernel arg structure. It primarily becomes relevant in cases like
5274 // bycopy, or byref range'd arrays. In the default case, we simply
5275 // pass thee pointer byref as both basePointer and pointer.
5276 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5277
5278 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5279
5280 // We operate under the assumption that all vectors that are
5281 // required in MapInfoData are of equal lengths (either filled with
5282 // default constructed data or appropiate information) so we can
5283 // utilise the size from any component of MapInfoData, if we can't
5284 // something is missing from the initial MapInfoData construction.
5285 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5286 // NOTE/TODO: We currently do not support arbitrary depth record
5287 // type mapping.
5288 if (mapData.IsAMember[i])
5289 continue;
5290
5291 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5292 if (!mapInfoOp.getMembers().empty()) {
5293 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5294 combinedInfo, mapData, i, targetDirective);
5295 continue;
5296 }
5297
5298 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5299 }
5300}
5301
// Forward declaration: the mapper-emission helpers are mutually recursive —
// getOrCreateUserDefinedMapperFunc calls emitUserDefinedMapper, whose
// customMapperCB in turn calls getOrCreateUserDefinedMapperFunc for nested
// mappers.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective);
5307
5308static llvm::Expected<llvm::Function *>
5309getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5310 LLVM::ModuleTranslation &moduleTranslation,
5311 TargetDirectiveEnumTy targetDirective) {
5312 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5313 "function only supported for host device codegen");
5314 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5315 std::string mapperFuncName =
5316 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5317 {"omp_mapper", declMapperOp.getSymName()});
5318
5319 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5320 return lookupFunc;
5321
5322 // Recursive types can cause re-entrant mapper emission. The mapper function
5323 // is created by OpenMPIRBuilder before the callbacks run, so it may already
5324 // exist in the LLVM module even though it is not yet registered in the
5325 // ModuleTranslation mapping table. Reuse and register it to break the
5326 // recursion.
5327 if (llvm::Function *existingFunc =
5328 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
5329 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
5330 return existingFunc;
5331 }
5332
5333 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
5334 mapperFuncName, targetDirective);
5335}
5336
// Emits the LLVM-IR implementation of a user-defined mapper
// (omp.declare_mapper) via OpenMPIRBuilder::emitUserDefinedMapper and
// registers the result under `mapperFuncName` in the ModuleTranslation
// function mapping table. Returns the emitted function, or the error
// produced while translating the mapper region.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName,
                      TargetDirectiveEnumTy targetDirective) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  // Callback run by OpenMPIRBuilder inside the generated mapper function.
  // `ptrPHI` is bound to the mapper's symbolic variable before the mapper
  // region is translated, then the region's map operands are collected into
  // `combinedInfo`.
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(codeGenIP);
    moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
    moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
                               builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
                                              /*ignoreArguments=*/true,
                                              builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
                targetDirective);

    // Drop the mapping that is no longer necessary so that the same region
    // can be processed multiple times.
    moduleTranslation.forgetMapping(declMapperOp.getRegion());
    return combinedInfo;
  };

  // Resolves a nested user-defined mapper attached to map entry `i`,
  // emitting it on demand (possibly recursively).
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      genMapInfoCB, varType, mapperFuncName, customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Re-entrant emission (see getOrCreateUserDefinedMapperFunc) may have
  // already registered the function; in that case just sanity-check that the
  // registered function is the one we emitted.
  if (llvm::Function *mappedFunc =
          moduleTranslation.lookupFunction(mapperFuncName)) {
    assert(mappedFunc == *newFn &&
           "mapper function mapping disagrees with emitted function");
  } else {
    moduleTranslation.mapFunction(mapperFuncName, *newFn);
  }
  return *newFn;
}
5398
5399static LogicalResult
5400convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
5401 LLVM::ModuleTranslation &moduleTranslation) {
5402 llvm::Value *ifCond = nullptr;
5403 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
5404 SmallVector<Value> mapVars;
5405 SmallVector<Value> useDevicePtrVars;
5406 SmallVector<Value> useDeviceAddrVars;
5407 llvm::omp::RuntimeFunction RTLFn;
5408 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
5409 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
5410
5411 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5412 llvm::OpenMPIRBuilder::TargetDataInfo info(
5413 /*RequiresDevicePointerInfo=*/true,
5414 /*SeparateBeginEndCalls=*/true);
5415 assert(!ompBuilder->Config.isTargetDevice() &&
5416 "target data/enter/exit/update are host ops");
5417 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
5418
5419 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
5420 llvm::Value *v = moduleTranslation.lookupValue(dev);
5421 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
5422 };
5423
5424 LogicalResult result =
5426 .Case([&](omp::TargetDataOp dataOp) {
5427 if (failed(checkImplementationStatus(*dataOp)))
5428 return failure();
5429
5430 if (auto ifVar = dataOp.getIfExpr())
5431 ifCond = moduleTranslation.lookupValue(ifVar);
5432
5433 if (mlir::Value devId = dataOp.getDevice())
5434 deviceID = getDeviceID(devId);
5435
5436 mapVars = dataOp.getMapVars();
5437 useDevicePtrVars = dataOp.getUseDevicePtrVars();
5438 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
5439 return success();
5440 })
5441 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
5442 if (failed(checkImplementationStatus(*enterDataOp)))
5443 return failure();
5444
5445 if (auto ifVar = enterDataOp.getIfExpr())
5446 ifCond = moduleTranslation.lookupValue(ifVar);
5447
5448 if (mlir::Value devId = enterDataOp.getDevice())
5449 deviceID = getDeviceID(devId);
5450
5451 RTLFn =
5452 enterDataOp.getNowait()
5453 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
5454 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
5455 mapVars = enterDataOp.getMapVars();
5456 info.HasNoWait = enterDataOp.getNowait();
5457 return success();
5458 })
5459 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
5460 if (failed(checkImplementationStatus(*exitDataOp)))
5461 return failure();
5462
5463 if (auto ifVar = exitDataOp.getIfExpr())
5464 ifCond = moduleTranslation.lookupValue(ifVar);
5465
5466 if (mlir::Value devId = exitDataOp.getDevice())
5467 deviceID = getDeviceID(devId);
5468
5469 RTLFn = exitDataOp.getNowait()
5470 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
5471 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
5472 mapVars = exitDataOp.getMapVars();
5473 info.HasNoWait = exitDataOp.getNowait();
5474 return success();
5475 })
5476 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
5477 if (failed(checkImplementationStatus(*updateDataOp)))
5478 return failure();
5479
5480 if (auto ifVar = updateDataOp.getIfExpr())
5481 ifCond = moduleTranslation.lookupValue(ifVar);
5482
5483 if (mlir::Value devId = updateDataOp.getDevice())
5484 deviceID = getDeviceID(devId);
5485
5486 RTLFn =
5487 updateDataOp.getNowait()
5488 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
5489 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
5490 mapVars = updateDataOp.getMapVars();
5491 info.HasNoWait = updateDataOp.getNowait();
5492 return success();
5493 })
5494 .DefaultUnreachable("unexpected operation");
5495
5496 if (failed(result))
5497 return failure();
5498 // Pretend we have IF(false) if we're not doing offload.
5499 if (!isOffloadEntry)
5500 ifCond = builder.getFalse();
5501
5502 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5503 MapInfoData mapData;
5504 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
5505 builder, useDevicePtrVars, useDeviceAddrVars);
5506
5507 // Fill up the arrays with all the mapped variables.
5508 MapInfosTy combinedInfo;
5509 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
5510 builder.restoreIP(codeGenIP);
5511 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
5512 targetDirective);
5513 return combinedInfo;
5514 };
5515
5516 // Define a lambda to apply mappings between use_device_addr and
5517 // use_device_ptr base pointers, and their associated block arguments.
5518 auto mapUseDevice =
5519 [&moduleTranslation](
5520 llvm::OpenMPIRBuilder::DeviceInfoTy type,
5522 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
5523 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
5524 for (auto [arg, useDevVar] :
5525 llvm::zip_equal(blockArgs, useDeviceVars)) {
5526
5527 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
5528 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
5529 : mapInfoOp.getVarPtr();
5530 };
5531
5532 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
5533 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
5534 mapInfoData.MapClause, mapInfoData.DevicePointers,
5535 mapInfoData.BasePointers)) {
5536 auto mapOp = cast<omp::MapInfoOp>(mapClause);
5537 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
5538 devicePointer != type)
5539 continue;
5540
5541 if (llvm::Value *devPtrInfoMap =
5542 mapper ? mapper(basePointer) : basePointer) {
5543 moduleTranslation.mapValue(arg, devPtrInfoMap);
5544 break;
5545 }
5546 }
5547 }
5548 };
5549
5550 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
5551 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
5552 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5553 // We must always restoreIP regardless of doing anything the caller
5554 // does not restore it, leading to incorrect (no) branch generation.
5555 builder.restoreIP(codeGenIP);
5556 assert(isa<omp::TargetDataOp>(op) &&
5557 "BodyGen requested for non TargetDataOp");
5558 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
5559 Region &region = cast<omp::TargetDataOp>(op).getRegion();
5560 switch (bodyGenType) {
5561 case BodyGenTy::Priv:
5562 // Check if any device ptr/addr info is available
5563 if (!info.DevicePtrInfoMap.empty()) {
5564 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5565 blockArgIface.getUseDeviceAddrBlockArgs(),
5566 useDeviceAddrVars, mapData,
5567 [&](llvm::Value *basePointer) -> llvm::Value * {
5568 if (!info.DevicePtrInfoMap[basePointer].second)
5569 return nullptr;
5570 return builder.CreateLoad(
5571 builder.getPtrTy(),
5572 info.DevicePtrInfoMap[basePointer].second);
5573 });
5574 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5575 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5576 mapData, [&](llvm::Value *basePointer) {
5577 return info.DevicePtrInfoMap[basePointer].second;
5578 });
5579
5580 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5581 moduleTranslation)))
5582 return llvm::make_error<PreviouslyReportedError>();
5583 }
5584 break;
5585 case BodyGenTy::DupNoPriv:
5586 if (info.DevicePtrInfoMap.empty()) {
5587 // For host device we still need to do the mapping for codegen,
5588 // otherwise it may try to lookup a missing value.
5589 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
5590 blockArgIface.getUseDeviceAddrBlockArgs(),
5591 useDeviceAddrVars, mapData);
5592 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
5593 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
5594 mapData);
5595 }
5596 break;
5597 case BodyGenTy::NoPriv:
5598 // If device info is available then region has already been generated
5599 if (info.DevicePtrInfoMap.empty()) {
5600 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
5601 moduleTranslation)))
5602 return llvm::make_error<PreviouslyReportedError>();
5603 }
5604 break;
5605 }
5606 return builder.saveIP();
5607 };
5608
5609 auto customMapperCB =
5610 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5611 if (!combinedInfo.Mappers[i])
5612 return nullptr;
5613 info.HasMapper = true;
5614 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
5615 moduleTranslation, targetDirective);
5616 };
5617
5618 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5619 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5620 findAllocaInsertPoint(builder, moduleTranslation);
5621 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
5622 if (isa<omp::TargetDataOp>(op))
5623 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5624 deviceID, ifCond, info, genMapInfoCB,
5625 customMapperCB,
5626 /*MapperFunc=*/nullptr, bodyGenCB,
5627 /*DeviceAddrCB=*/nullptr);
5628 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
5629 deviceID, ifCond, info, genMapInfoCB,
5630 customMapperCB, &RTLFn);
5631 }();
5632
5633 if (failed(handleError(afterIP, *op)))
5634 return failure();
5635
5636 builder.restoreIP(*afterIP);
5637 return success();
5638}
5639
/// Translate `omp.distribute` to LLVM IR. When the enclosing `omp.teams`
/// reduction is contained entirely inside this distribute, the teams
/// reduction is also materialized here (allocation before the construct,
/// combination/cleanup after it).
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  /// Process teams op reduction in distribute if the reduction is contained in
  /// the distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  if (doDistributeReduction) {
    isByRef = getIsByRef(teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(teamsOp, reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
            .getReductionBlockArgs();

    if (failed(allocAndInitializeReductionVars(
            teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(codeGenIP);
    PrivateVarsInfo privVarsInfo(distributeOp);

    // Allocate, initialize and (if firstprivate) copy the privatized
    // variables before emitting the region body.
    llvm::Expected<llvm::BasicBlock *> afterAllocas =
        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
    if (handleError(afterAllocas, opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                    opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
            privVarsInfo.llvmVars, privVarsInfo.privatizers,
            distributeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    llvm::Expected<llvm::BasicBlock *> regionBlock =
        convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());

    // Skip applying a workshare loop below when translating 'distribute
    // parallel do' (it's been already handled by this point while translating
    // the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
                                      : omp::ClauseScheduleKind::Static;
      // dist_schedule clauses are ordered - otherwise this should be false
      bool isOrdered = hasDistSchedule;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      llvm::Value *chunk = moduleTranslation.lookupValue(
          distributeOp.getDistScheduleChunkSize());
      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
              convertToScheduleKind(schedule), chunk, isSimd,
              scheduleMod == omp::ScheduleModifier::monotonic,
              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
              workshareLoopType, false, hasDistSchedule, chunk);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }
    if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
                                  privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
    return createReductionsAndCleanup(
        teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ false, /*isTeamsReduction*/ true);
  }
  return success();
}
5775
5776/// Lowers the FlagsAttr which is applied to the module on the device
5777/// pass when offloading, this attribute contains OpenMP RTL globals that can
5778/// be passed as flags to the frontend, otherwise they are set to default
5779static LogicalResult
5780convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
5781 LLVM::ModuleTranslation &moduleTranslation) {
5782 if (!cast<mlir::ModuleOp>(op))
5783 return failure();
5784
5785 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5786
5787 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
5788 attribute.getOpenmpDeviceVersion());
5789
5790 if (attribute.getNoGpuLib())
5791 return success();
5792
5793 ompBuilder->createGlobalFlag(
5794 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
5795 "__omp_rtl_debug_kind");
5796 ompBuilder->createGlobalFlag(
5797 attribute
5798 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
5799 ,
5800 "__omp_rtl_assume_teams_oversubscription");
5801 ompBuilder->createGlobalFlag(
5802 attribute
5803 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
5804 ,
5805 "__omp_rtl_assume_threads_oversubscription");
5806 ompBuilder->createGlobalFlag(
5807 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
5808 "__omp_rtl_assume_no_thread_state");
5809 ompBuilder->createGlobalFlag(
5810 attribute
5811 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
5812 ,
5813 "__omp_rtl_assume_no_nested_parallelism");
5814 return success();
5815}
5816
5817static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
5818 omp::TargetOp targetOp,
5819 llvm::StringRef parentName = "") {
5820 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
5821
5822 assert(fileLoc && "No file found from location");
5823 StringRef fileName = fileLoc.getFilename().getValue();
5824
5825 llvm::sys::fs::UniqueID id;
5826 uint64_t line = fileLoc.getLine();
5827 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
5828 size_t fileHash = llvm::hash_value(fileName.str());
5829 size_t deviceId = 0xdeadf17e;
5830 targetInfo =
5831 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
5832 } else {
5833 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
5834 id.getFile(), line);
5835 }
5836}
5837
/// For target device codegen, rewrite every in-kernel use of a declare-target
/// mapped global (the original variable) to go through the reference pointer
/// produced by convertDeclareTargetAttr, inserting a load per use for
/// declare-target-link variables.
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for target device codegen");
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // In the case of declare target mapped variables, the basePointer is
    // the reference pointer generated by the convertDeclareTargetAttr
    // method. Whereas the kernelValue is the original variable, so for
    // the device we must replace all uses of this original global variable
    // (stored in kernelValue) with the reference pointer (stored in
    // basePointer for declare target mapped variables), as for device the
    // data is mapped into this reference pointer and should be loaded
    // from it, the original variable is discarded. On host both exist and
    // metadata is generated (elsewhere in the convertDeclareTargetAttr)
    // function to link the two variables in the runtime and then both the
    // reference pointer and the pointer are assigned in the kernel argument
    // structure for the host.
    if (mapData.IsDeclareTarget[i]) {
      // If the original map value is a constant, then we have to make sure all
      // of it's uses within the current kernel/function that we are going to
      // rewrite are converted to instructions, as we will be altering the old
      // use (OriginalValue) from a constant to an instruction, which will be
      // illegal and ICE the compiler if the user is a constant expression of
      // some kind e.g. a constant GEP.
      if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(constant, func, false);

      // The users iterator will get invalidated if we modify an element,
      // so we populate this vector of uses to alter each user on an
      // individual basis to emit its own load (rather than one load for
      // all).
      llvm::SmallVector<llvm::User *> userVec;
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(user);

      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
          // Only rewrite uses that live inside the kernel being generated.
          if (insn->getFunction() == func) {
            auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
            llvm::Value *substitute = mapData.BasePointers[i];
            // declare-target-link variables are accessed indirectly: emit a
            // load of the reference pointer immediately before the use.
            if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
                                                         : mapOp.getVarPtr())) {
              builder.SetCurrentDebugLocation(insn->getDebugLoc());
              substitute = builder.CreateLoad(
                  mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
              cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
            }
            user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
          }
        }
      }
    }
  }
}
5895
5896// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
5898// arguments inside of the device kernel for use by
5899// the kernel. This enables different semantics such as
5900// the creation of temporary copies of data allowing
5901// semantics like read-only/no host write back kernel
5902// arguments.
5903//
5904// This currently implements a very light version of Clang's
5905// EmitParmDecl's handling of direct argument handling as well
5906// as a portion of the argument access generation based on
5907// capture types found at the end of emitOutlinedFunctionPrologue
5908// in Clang. The indirect path handling of EmitParmDecl's may be
5909// required for future work, but a direct 1-to-1 copy doesn't seem
5910// possible as the logic is rather scattered throughout Clang's
5911// lowering and perhaps we wish to deviate slightly.
5912//
5913// \param mapData - A container containing vectors of information
5914// corresponding to the input argument, which should have a
5915// corresponding entry in the MapInfoData containers
// OriginalValue's.
5917// \param arg - This is the generated kernel function argument that
5918// corresponds to the passed in input argument. We generated different
5919// accesses of this Argument, based on capture type and other Input
5920// related information.
5921// \param input - This is the host side value that will be passed to
5922// the kernel i.e. the kernel input, we rewrite all uses of this within
5923// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
5928// \param retVal - This is the value that all uses of input inside of the
5929// kernel will be re-written to, the goal of this function is to generate
5930// an appropriate location for the kernel argument to be accessed from,
5931// e.g. ByRef will result in a temporary allocation location and then
5932// a store of the kernel argument into this allocated memory which
5933// will then be loaded from, ByCopy will use the allocated memory
5934// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // Allocas must go to the kernel's alloca insertion point, not the current
  // codegen position.
  builder.restoreIP(allocaIP);

  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          mapOp.getVarType(), ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current (alloca) insertion
  // point.
  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);

  // If allocas live in a non-default address space (e.g. AMDGPU private),
  // cast pointer-typed slots back to the program address space.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));

  // Spill the incoming kernel argument into the slot.
  builder.CreateStore(&arg, v);

  // Switch back to the body position to emit the access itself.
  builder.restoreIP(codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy uses the spilled slot directly.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // ByRef reloads the original pointer from the slot.
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        v->getType(), v,
        ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //    alignment of %input    |
    //                           |
    //                 alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          llvm::LLVMContext::MD_align,
          llvm::MDNode::get(builder.getContext(),
                            MDB.createConstant(llvm::ConstantInt::get(
                                llvm::Type::getInt64Ty(builder.getContext()),
                                alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
6022
6023/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6024/// operation and populate output variables with their corresponding host value
6025/// (i.e. operand evaluated outside of the target region), based on their uses
6026/// inside of the target region.
6027///
6028/// Loop bounds and steps are only optionally populated, if output vectors are
6029/// provided.
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                       Value &numTeamsLower, Value &numTeamsUpper,
                       Value &threadLimit,
                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
  // Walk each host_eval operand together with the block argument it maps to.
  for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                   blockArgIface.getHostEvalBlockArgs())) {
    Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);

    // The use site inside the region determines which clause the host value
    // corresponds to.
    for (Operation *user : blockArg.getUsers()) {
      llvm::TypeSwitch<Operation *>(user)
          .Case([&](omp::TeamsOp teamsOp) {
            if (teamsOp.getNumTeamsLower() == blockArg)
              numTeamsLower = hostEvalVar;
            else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
                                        blockArg))
              numTeamsUpper = hostEvalVar;
            else if (!teamsOp.getThreadLimitVars().empty() &&
                     teamsOp.getThreadLimit(0) == blockArg)
              threadLimit = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::ParallelOp parallelOp) {
            if (!parallelOp.getNumThreadsVars().empty() &&
                parallelOp.getNumThreads(0) == blockArg)
              numThreads = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::LoopNestOp loopOp) {
            // A loop-nest use is a bound or step; record it in the matching
            // output slot when the caller asked for it.
            auto processBounds =
                [&](OperandRange opBounds,
                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
              bool found = false;
              for (auto [i, lb] : llvm::enumerate(opBounds)) {
                if (lb == blockArg) {
                  found = true;
                  if (outBounds)
                    (*outBounds)[i] = hostEvalVar;
                }
              }
              return found;
            };
            bool found =
                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
                    found;
            found = processBounds(loopOp.getLoopSteps(), steps) || found;
            (void)found;
            assert(found && "unsupported host_eval use");
          })
          .DefaultUnreachable("unsupported host_eval use");
    }
  }
}
6089
6090/// If \p op is of the given type parameter, return it casted to that type.
6091/// Otherwise, if its immediate parent operation (or some other higher-level
6092/// parent, if \p immediateParent is false) is of that type, return that parent
6093/// casted to the given type.
6094///
6095/// If \p op is \c null or neither it or its parent(s) are of the specified
6096/// type, return a \c null operation.
6097template <typename OpTy>
6098static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6099 if (!op)
6100 return OpTy();
6101
6102 if (OpTy casted = dyn_cast<OpTy>(op))
6103 return casted;
6104
6105 if (immediateParent)
6106 return dyn_cast_if_present<OpTy>(op->getParentOp());
6107
6108 return op->getParentOfType<OpTy>();
6109}
6110
6111/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6112/// it is of an integer type, return its value.
6113static std::optional<int64_t> extractConstInteger(Value value) {
6114 if (!value)
6115 return std::nullopt;
6116
6117 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6118 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6119 return constAttr.getInt();
6120
6121 return std::nullopt;
6122}
6123
6124static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6125 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6126 uint64_t sizeInBytes = sizeInBits / 8;
6127 return sizeInBytes;
6128}
6129
/// Compute the byte size of a literal struct holding one member per reduction
/// variable of \p op, as used to size the device reduction buffer. Returns 0
/// when the operation has no reduction variables.
template <typename OpTy>
static uint64_t getReductionDataSize(OpTy &op) {
  if (op.getNumReductionVars() > 0) {
    SmallVector<omp::DeclareReductionOp> reductions;
    collectReductionDecls(op, reductions);

    // Build a packed-as-struct view of all reduction element types and
    // measure it with the module's data layout.
    llvm::SmallVector<mlir::Type> members;
    members.reserve(reductions.size());
    for (omp::DeclareReductionOp &red : reductions)
      members.push_back(red.getType());
    Operation *opp = op.getOperation();
    auto structType = mlir::LLVM::LLVMStructType::getLiteral(
        opp->getContext(), members, /*isPacked=*/false);
    DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
    return getTypeByteSize(structType, dl);
  }
  return 0;
}
6148
6149/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
6150/// values as stated by the corresponding clauses, if constant.
6151///
6152/// These default values must be set before the creation of the outlined LLVM
6153/// function for the target region, so that they can be used to initialize the
6154/// corresponding global `ConfigurationEnvironmentTy` structure.
static void
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice, bool isGPU) {
  // TODO: Handle constant 'if' clauses.

  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      // Handle num_teams upper bounds (only first value for now)
      if (!teamsOp.getNumTeamsUpperVars().empty())
        numTeamsUpper = teamsOp.getNumTeams(0);
      if (!teamsOp.getThreadLimitVars().empty())
        threadLimit = teamsOp.getThreadLimit(0);
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
      if (!parallelOp.getNumThreadsVars().empty())
        numThreads = parallelOp.getNumThreads(0);
    }
  }

  // Handle clauses impacting the number of teams.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
    // match clang and set min and max to the same value.
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                    /*immediateParent=*/true) ||
             castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                /*immediateParent=*/true)) {
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Record the clause's constant value into `result`; a non-constant clause
  // leaves `result` at 0 ("set but unknown").
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  if (!targetOp.getThreadLimitVars().empty())
    setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // Reduction buffer sizing is only relevant for GPU teams reductions.
  int32_t reductionDataSize = 0;
  if (isGPU && capturedOp) {
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
      reductionDataSize = getReductionDataSize(teamsOp);
  }

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
  assert(
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                               omp::TargetRegionFlags::spmd) &&
      "invalid kernel flags");
  attrs.ExecFlags =
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
  if (omp::bitEnumContainsAll(kernelFlags,
                              omp::TargetRegionFlags::spmd |
                                  omp::TargetRegionFlags::no_loop) &&
      !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
    attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;

  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
  attrs.ReductionDataSize = reductionDataSize;
  // TODO: Allow modified buffer length similar to
  // fopenmp-cuda-teams-reduction-recs-num flag in clang.
  if (attrs.ReductionDataSize != 0)
    attrs.ReductionBufferLength = 1024;
}
6280
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
///
/// \param builder           IR builder positioned at the kernel-launch point in
///                          the host function.
/// \param moduleTranslation Mapping from MLIR values to translated LLVM values.
/// \param targetOp          The `omp.target` operation being translated.
/// \param capturedOp        Innermost OpenMP operation captured by the target
///                          region, if any (may be null).
/// \param attrs             Output structure populated with runtime launch
///                          values for the `OpenMPIRBuilder`.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  // If the captured construct is (nested inside) a loop nest, its bounds are
  // needed below to compute the kernel's trip count.
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  // Collect clause values that were evaluated on the host and forwarded into
  // the target region via `host_eval`.
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);

  // TODO: Handle constant 'if' clauses.
  if (!targetOp.getThreadLimitVars().empty()) {
    Value targetThreadLimit = targetOp.getThreadLimit(0);
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(targetThreadLimit);
  }

  // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
  // truncate or sign extend lower and upper num_teams bounds as well as
  // thread_limit to match int32 ABI requirements for the OpenMP runtime.
  if (numTeamsLower)
    attrs.MinTeams = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());

  if (numTeamsUpper)
    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

  // Only kernels whose exec flags request a trip count get one computed here.
  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                              omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(loopStep);

      // If any bound is unavailable, the overall trip count is unknown; reset
      // and bail out of the accumulation.
      if (!lowerBound || !upperBound || !step) {
        attrs.LoopTripCount = nullptr;
        break;
      }

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
          loopOp.getLoopInclusive());

      // First loop level: seed the accumulator.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
                                              {}, /*HasNUW=*/true);
    }
  }

  // Device ID defaults to "undefined" unless an explicit `device` clause value
  // is present; the runtime receives it as a 64-bit integer.
  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
  if (mlir::Value devId = targetOp.getDevice()) {
    attrs.DeviceID = moduleTranslation.lookupValue(devId);
    attrs.DeviceID =
        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
  }
}
6368
/// Translate an `omp.target` operation to LLVM IR. Depending on the
/// compilation mode this generates the host-side kernel launch, the host
/// fallback, and/or (via the OpenMPIRBuilder) the outlined kernel function.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  // The current debug location already has the DISubprogram for the outlined
  // function that will be created for the target op. We save it here so that
  // we can set it on the outlined function.
  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // During the handling of target op, we will generate instructions in the
  // parent function like call to the oulined function or branch to a new
  // BasicBlock. We set the debug location here to parent function so that those
  // get the correct debug locations. For outlined functions, the normal MLIR op
  // conversion will automatically pick the correct location.
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
        outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block
  // argument that corresponds to the MapInfoOp corresponding to the private
  // var in question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;
  TargetDirectiveEnumTy targetDirective =
      getTargetDirectiveEnumTyFromOp(&opInst);

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during codegeneration we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that require a mapped descriptor are recorded.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback that generates the body of the target region (both for the
  // outlined kernel and the host fallback).
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map each map/has_device_addr entry block argument to the LLVM value of
    // the variable it refers to, so the region body can be translated.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    PrivateVarsInfo privateVarsInfo(targetOp);

    // NOTE(review): a declaration line appears to be missing from this copy
    // right here (`afterAllocas` is consumed below but never declared;
    // presumably `llvm::Expected<llvm::BasicBlock *> afterAllocas =`) —
    // verify against upstream before building.
        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
                            allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(codeGenIP);
    if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                    &mappedPrivateVars),
                    *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect the `dealloc` regions of all privatizers so they can be inlined
    // as cleanup after the region body.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateVarsInfo.privatizers,
                    std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    // NOTE(review): another declaration line appears to be missing here
    // (`exitBlock` is used below but never declared; presumably
    // `llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(`)
    // — verify against upstream before building.
        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, privateVarsInfo.llvmVars,
              moduleTranslation, builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hdaVars);

  MapInfosTy combinedInfos;
  // Callback producing the combined map information for the kernel arguments.
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
                targetDirective);
    return combinedInfos;
  };

  // Callback that materializes each kernel argument inside the target region.
  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           targetCapturedOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  // NOTE(review): the declaration of `kernelInput` (presumably
  // `SmallVector<llvm::Value *> kernelInput;`) appears to be missing from
  // this copy — verify against upstream before building.
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likley need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  // NOTE(review): the declaration of `dds` (presumably
  // `SmallVector<llvm::OpenMPIRBuilder::DependData> dds;`) appears to be
  // missing from this copy — verify against upstream before building.
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Callback returning the user-defined mapper function for map entry `i`, or
  // null when that entry has no custom mapper.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
                                            moduleTranslation, targetDirective);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, customMapperCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
6688
6689static LogicalResult
6690convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
6691 LLVM::ModuleTranslation &moduleTranslation) {
6692 // Amend omp.declare_target by deleting the IR of the outlined functions
6693 // created for target regions. They cannot be filtered out from MLIR earlier
6694 // because the omp.target operation inside must be translated to LLVM, but
6695 // the wrapper functions themselves must not remain at the end of the
6696 // process. We know that functions where omp.declare_target does not match
6697 // omp.is_target_device at this stage can only be wrapper functions because
6698 // those that aren't are removed earlier as an MLIR transformation pass.
6699 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
6700 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
6701 op->getParentOfType<ModuleOp>().getOperation())) {
6702 if (!offloadMod.getIsTargetDevice())
6703 return success();
6704
6705 omp::DeclareTargetDeviceType declareType =
6706 attribute.getDeviceType().getValue();
6707
6708 if (declareType == omp::DeclareTargetDeviceType::host) {
6709 llvm::Function *llvmFunc =
6710 moduleTranslation.lookupFunction(funcOp.getName());
6711 llvmFunc->dropAllReferences();
6712 llvmFunc->eraseFromParent();
6713 }
6714 }
6715 return success();
6716 }
6717
6718 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
6719 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6720 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
6721 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6722 bool isDeclaration = gOp.isDeclaration();
6723 bool isExternallyVisible =
6724 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
6725 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
6726 llvm::StringRef mangledName = gOp.getSymName();
6727 auto captureClause =
6728 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
6729 auto deviceClause =
6730 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
6731 // unused for MLIR at the moment, required in Clang for book
6732 // keeping
6733 std::vector<llvm::GlobalVariable *> generatedRefs;
6734
6735 std::vector<llvm::Triple> targetTriple;
6736 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
6737 op->getParentOfType<mlir::ModuleOp>()->getAttr(
6738 LLVM::LLVMDialect::getTargetTripleAttrName()));
6739 if (targetTripleAttr)
6740 targetTriple.emplace_back(targetTripleAttr.data());
6741
6742 auto fileInfoCallBack = [&loc]() {
6743 std::string filename = "";
6744 std::uint64_t lineNo = 0;
6745
6746 if (loc) {
6747 filename = loc.getFilename().str();
6748 lineNo = loc.getLine();
6749 }
6750
6751 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
6752 lineNo);
6753 };
6754
6755 auto vfs = llvm::vfs::getRealFileSystem();
6756
6757 ompBuilder->registerTargetGlobalVariable(
6758 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6759 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6760 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6761 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
6762 gVal->getType(), gVal);
6763
6764 if (ompBuilder->Config.isTargetDevice() &&
6765 (attribute.getCaptureClause().getValue() !=
6766 mlir::omp::DeclareTargetCaptureClause::to ||
6767 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
6768 ompBuilder->getAddrOfDeclareTargetVar(
6769 captureClause, deviceClause, isDeclaration, isExternallyVisible,
6770 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
6771 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
6772 gVal->getType(), /*GlobalInitializer*/ nullptr,
6773 /*VariableLinkage*/ nullptr);
6774 }
6775 }
6776 }
6777
6778 return success();
6779}
6780
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  // NOTE(review): the inherited-constructor using-declaration
  // (`using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;`)
  // appears to be missing from this copy — verify against upstream.

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
6805
6806LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6807 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6808 NamedAttribute attribute,
6809 LLVM::ModuleTranslation &moduleTranslation) const {
6810 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6811 attribute.getName())
6812 .Case("omp.is_target_device",
6813 [&](Attribute attr) {
6814 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
6815 llvm::OpenMPIRBuilderConfig &config =
6816 moduleTranslation.getOpenMPBuilder()->Config;
6817 config.setIsTargetDevice(deviceAttr.getValue());
6818 return success();
6819 }
6820 return failure();
6821 })
6822 .Case("omp.is_gpu",
6823 [&](Attribute attr) {
6824 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
6825 llvm::OpenMPIRBuilderConfig &config =
6826 moduleTranslation.getOpenMPBuilder()->Config;
6827 config.setIsGPU(gpuAttr.getValue());
6828 return success();
6829 }
6830 return failure();
6831 })
6832 .Case("omp.host_ir_filepath",
6833 [&](Attribute attr) {
6834 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6835 llvm::OpenMPIRBuilder *ompBuilder =
6836 moduleTranslation.getOpenMPBuilder();
6837 auto VFS = llvm::vfs::getRealFileSystem();
6838 ompBuilder->loadOffloadInfoMetadata(*VFS,
6839 filepathAttr.getValue());
6840 return success();
6841 }
6842 return failure();
6843 })
6844 .Case("omp.flags",
6845 [&](Attribute attr) {
6846 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6847 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6848 return failure();
6849 })
6850 .Case("omp.version",
6851 [&](Attribute attr) {
6852 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6853 llvm::OpenMPIRBuilder *ompBuilder =
6854 moduleTranslation.getOpenMPBuilder();
6855 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6856 versionAttr.getVersion());
6857 return success();
6858 }
6859 return failure();
6860 })
6861 .Case("omp.declare_target",
6862 [&](Attribute attr) {
6863 if (auto declareTargetAttr =
6864 dyn_cast<omp::DeclareTargetAttr>(attr))
6865 return convertDeclareTargetAttr(op, declareTargetAttr,
6866 moduleTranslation);
6867 return failure();
6868 })
6869 .Case("omp.requires",
6870 [&](Attribute attr) {
6871 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6872 using Requires = omp::ClauseRequires;
6873 Requires flags = requiresAttr.getValue();
6874 llvm::OpenMPIRBuilderConfig &config =
6875 moduleTranslation.getOpenMPBuilder()->Config;
6876 config.setHasRequiresReverseOffload(
6877 bitEnumContainsAll(flags, Requires::reverse_offload));
6878 config.setHasRequiresUnifiedAddress(
6879 bitEnumContainsAll(flags, Requires::unified_address));
6880 config.setHasRequiresUnifiedSharedMemory(
6881 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6882 config.setHasRequiresDynamicAllocators(
6883 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6884 return success();
6885 }
6886 return failure();
6887 })
6888 .Case("omp.target_triples",
6889 [&](Attribute attr) {
6890 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6891 llvm::OpenMPIRBuilderConfig &config =
6892 moduleTranslation.getOpenMPBuilder()->Config;
6893 config.TargetTriples.clear();
6894 config.TargetTriples.reserve(triplesAttr.size());
6895 for (Attribute tripleAttr : triplesAttr) {
6896 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6897 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6898 else
6899 return failure();
6900 }
6901 return success();
6902 }
6903 return failure();
6904 })
6905 .Default([](Attribute) {
6906 // Fall through for omp attributes that do not require lowering.
6907 return success();
6908 })(attribute.getValue());
6909
6910 return failure();
6911}
6912
6913// Returns true if the operation is not inside a TargetOp, it is part of a
6914// function and that function is not declare target.
6915static bool isHostDeviceOp(Operation *op) {
6916 // Assumes no reverse offloading
6917 if (op->getParentOfType<omp::TargetOp>())
6918 return false;
6919
6920 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
6921 if (auto declareTargetIface =
6922 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
6923 parentFn.getOperation()))
6924 if (declareTargetIface.isDeclareTarget() &&
6925 declareTargetIface.getDeclareTargetDeviceType() !=
6926 mlir::omp::DeclareTargetDeviceType::host)
6927 return false;
6928
6929 return true;
6930 }
6931
6932 return false;
6933}
6934
6935static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
6936 llvm::Module *llvmModule) {
6937 llvm::Type *i64Ty = builder.getInt64Ty();
6938 llvm::Type *i32Ty = builder.getInt32Ty();
6939 llvm::Type *returnType = builder.getPtrTy(0);
6940 llvm::FunctionType *fnType =
6941 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
6942 llvm::Function *func = cast<llvm::Function>(
6943 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
6944 return func;
6945}
6946
6947static LogicalResult
6948convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6949 LLVM::ModuleTranslation &moduleTranslation) {
6950 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
6951 if (!allocMemOp)
6952 return failure();
6953
6954 // Get "omp_target_alloc" function
6955 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6956 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
6957 // Get the corresponding device value in llvm
6958 mlir::Value deviceNum = allocMemOp.getDevice();
6959 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
6960 // Get the allocation size.
6961 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
6962 mlir::Type heapTy = allocMemOp.getAllocatedType();
6963 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
6964 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
6965 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
6966 for (auto typeParam : allocMemOp.getTypeparams())
6967 allocSize =
6968 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
6969 // Create call to "omp_target_alloc" with the args as translated llvm values.
6970 llvm::CallInst *call =
6971 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
6972 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
6973
6974 // Map the result
6975 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
6976 return success();
6977}
6978
6979static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
6980 llvm::Module *llvmModule) {
6981 llvm::Type *ptrTy = builder.getPtrTy(0);
6982 llvm::Type *i32Ty = builder.getInt32Ty();
6983 llvm::Type *voidTy = builder.getVoidTy();
6984 llvm::FunctionType *fnType =
6985 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
6986 llvm::Function *func = dyn_cast<llvm::Function>(
6987 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
6988 return func;
6989}
6990
6991static LogicalResult
6992convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
6993 LLVM::ModuleTranslation &moduleTranslation) {
6994 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
6995 if (!freeMemOp)
6996 return failure();
6997
6998 // Get "omp_target_free" function
6999 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7000 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7001 // Get the corresponding device value in llvm
7002 mlir::Value deviceNum = freeMemOp.getDevice();
7003 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7004 // Get the corresponding heapref value in llvm
7005 mlir::Value heapref = freeMemOp.getHeapref();
7006 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7007 // Convert heapref int to ptr and call "omp_target_free"
7008 llvm::Value *intToPtr =
7009 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7010 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7011 return success();
7012}
7013
7014/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
7015/// OpenMP runtime calls).
7016LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7017 Operation *op, llvm::IRBuilderBase &builder,
7018 LLVM::ModuleTranslation &moduleTranslation) const {
7019 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7020
7021 if (ompBuilder->Config.isTargetDevice() &&
7022 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7023 op) &&
7024 isHostDeviceOp(op))
7025 return op->emitOpError() << "unsupported host op found in device";
7026
7027 // For each loop, introduce one stack frame to hold loop information. Ensure
7028 // this is only done for the outermost loop wrapper to prevent introducing
7029 // multiple stack frames for a single loop. Initially set to null, the loop
7030 // information structure is initialized during translation of the nested
7031 // omp.loop_nest operation, making it available to translation of all loop
7032 // wrappers after their body has been successfully translated.
7033 bool isOutermostLoopWrapper =
7034 isa_and_present<omp::LoopWrapperInterface>(op) &&
7035 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
7036
7037 if (isOutermostLoopWrapper)
7038 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7039
7040 auto result =
7041 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7042 .Case([&](omp::BarrierOp op) -> LogicalResult {
7044 return failure();
7045
7046 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7047 ompBuilder->createBarrier(builder.saveIP(),
7048 llvm::omp::OMPD_barrier);
7049 LogicalResult res = handleError(afterIP, *op);
7050 if (res.succeeded()) {
7051 // If the barrier generated a cancellation check, the insertion
7052 // point might now need to be changed to a new continuation block
7053 builder.restoreIP(*afterIP);
7054 }
7055 return res;
7056 })
7057 .Case([&](omp::TaskyieldOp op) {
7059 return failure();
7060
7061 ompBuilder->createTaskyield(builder.saveIP());
7062 return success();
7063 })
7064 .Case([&](omp::FlushOp op) {
7066 return failure();
7067
7068 // No support in Openmp runtime function (__kmpc_flush) to accept
7069 // the argument list.
7070 // OpenMP standard states the following:
7071 // "An implementation may implement a flush with a list by ignoring
7072 // the list, and treating it the same as a flush without a list."
7073 //
7074 // The argument list is discarded so that, flush with a list is
7075 // treated same as a flush without a list.
7076 ompBuilder->createFlush(builder.saveIP());
7077 return success();
7078 })
7079 .Case([&](omp::ParallelOp op) {
7080 return convertOmpParallel(op, builder, moduleTranslation);
7081 })
7082 .Case([&](omp::MaskedOp) {
7083 return convertOmpMasked(*op, builder, moduleTranslation);
7084 })
7085 .Case([&](omp::MasterOp) {
7086 return convertOmpMaster(*op, builder, moduleTranslation);
7087 })
7088 .Case([&](omp::CriticalOp) {
7089 return convertOmpCritical(*op, builder, moduleTranslation);
7090 })
7091 .Case([&](omp::OrderedRegionOp) {
7092 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
7093 })
7094 .Case([&](omp::OrderedOp) {
7095 return convertOmpOrdered(*op, builder, moduleTranslation);
7096 })
7097 .Case([&](omp::WsloopOp) {
7098 return convertOmpWsloop(*op, builder, moduleTranslation);
7099 })
7100 .Case([&](omp::SimdOp) {
7101 return convertOmpSimd(*op, builder, moduleTranslation);
7102 })
7103 .Case([&](omp::AtomicReadOp) {
7104 return convertOmpAtomicRead(*op, builder, moduleTranslation);
7105 })
7106 .Case([&](omp::AtomicWriteOp) {
7107 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
7108 })
7109 .Case([&](omp::AtomicUpdateOp op) {
7110 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
7111 })
7112 .Case([&](omp::AtomicCaptureOp op) {
7113 return convertOmpAtomicCapture(op, builder, moduleTranslation);
7114 })
7115 .Case([&](omp::CancelOp op) {
7116 return convertOmpCancel(op, builder, moduleTranslation);
7117 })
7118 .Case([&](omp::CancellationPointOp op) {
7119 return convertOmpCancellationPoint(op, builder, moduleTranslation);
7120 })
7121 .Case([&](omp::SectionsOp) {
7122 return convertOmpSections(*op, builder, moduleTranslation);
7123 })
7124 .Case([&](omp::SingleOp op) {
7125 return convertOmpSingle(op, builder, moduleTranslation);
7126 })
7127 .Case([&](omp::TeamsOp op) {
7128 return convertOmpTeams(op, builder, moduleTranslation);
7129 })
7130 .Case([&](omp::TaskOp op) {
7131 return convertOmpTaskOp(op, builder, moduleTranslation);
7132 })
7133 .Case([&](omp::TaskloopOp op) {
7134 return convertOmpTaskloopOp(*op, builder, moduleTranslation);
7135 })
7136 .Case([&](omp::TaskgroupOp op) {
7137 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
7138 })
7139 .Case([&](omp::TaskwaitOp op) {
7140 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
7141 })
7142 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
7143 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
7144 omp::CriticalDeclareOp>([](auto op) {
7145 // `yield` and `terminator` can be just omitted. The block structure
7146 // was created in the region that handles their parent operation.
7147 // `declare_reduction` will be used by reductions and is not
7148 // converted directly, skip it.
7149 // `declare_mapper` and `declare_mapper.info` are handled whenever
7150 // they are referred to through a `map` clause.
7151 // `critical.declare` is only used to declare names of critical
7152 // sections which will be used by `critical` ops and hence can be
7153 // ignored for lowering. The OpenMP IRBuilder will create unique
7154 // name for critical section names.
7155 return success();
7156 })
7157 .Case([&](omp::ThreadprivateOp) {
7158 return convertOmpThreadprivate(*op, builder, moduleTranslation);
7159 })
7160 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
7161 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
7162 return convertOmpTargetData(op, builder, moduleTranslation);
7163 })
7164 .Case([&](omp::TargetOp) {
7165 return convertOmpTarget(*op, builder, moduleTranslation);
7166 })
7167 .Case([&](omp::DistributeOp) {
7168 return convertOmpDistribute(*op, builder, moduleTranslation);
7169 })
7170 .Case([&](omp::LoopNestOp) {
7171 return convertOmpLoopNest(*op, builder, moduleTranslation);
7172 })
7173 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
7174 [&](auto op) {
7175 // No-op, should be handled by relevant owning operations e.g.
7176 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
7177 // etc. and then discarded
7178 return success();
7179 })
7180 .Case([&](omp::NewCliOp op) {
7181 // Meta-operation: Doesn't do anything by itself, but used to
7182 // identify a loop.
7183 return success();
7184 })
7185 .Case([&](omp::CanonicalLoopOp op) {
7186 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
7187 })
7188 .Case([&](omp::UnrollHeuristicOp op) {
7189 // FIXME: Handling omp.unroll_heuristic as an executable requires
7190 // that the generator (e.g. omp.canonical_loop) has been seen first.
7191 // For construct that require all codegen to occur inside a callback
7192 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
7193 // contained region including their transformations must occur at
7194 // the omp.canonical_loop.
7195 return applyUnrollHeuristic(op, builder, moduleTranslation);
7196 })
7197 .Case([&](omp::TileOp op) {
7198 return applyTile(op, builder, moduleTranslation);
7199 })
7200 .Case([&](omp::TargetAllocMemOp) {
7201 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
7202 })
7203 .Case([&](omp::TargetFreeMemOp) {
7204 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
7205 })
7206 .Default([&](Operation *inst) {
7207 return inst->emitError()
7208 << "not yet implemented: " << inst->getName();
7209 });
7210
7211 if (isOutermostLoopWrapper)
7212 moduleTranslation.stackPop();
7213
7214 return result;
7215}
7216
7218 registry.insert<omp::OpenMPDialect>();
7219 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
7220 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
7221 });
7222}
7223
7225 DialectRegistry registry;
7227 context.appendDialectRegistry(registry);
7228}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp)
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static void popCancelFinalizationCB(const ArrayRef< llvm::BranchInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::BranchInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a terminator for the new block created for the branch t...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static void buildDependData(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Value getOperand(unsigned idx)
Definition Operation.h:350
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1297
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars