MLIR 23.0.0git
OpenMPToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
19#include "mlir/IR/Operation.h"
21#include "mlir/Support/LLVM.h"
24
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Frontend/OpenMP/OMPConstants.h"
29#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30#include "llvm/IR/Constants.h"
31#include "llvm/IR/DebugInfoMetadata.h"
32#include "llvm/IR/DerivedTypes.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/MDBuilder.h"
35#include "llvm/IR/ReplaceConstant.h"
36#include "llvm/Support/FileSystem.h"
37#include "llvm/Support/VirtualFileSystem.h"
38#include "llvm/TargetParser/Triple.h"
39#include "llvm/Transforms/Utils/ModuleUtils.h"
40
41#include <cstdint>
42#include <iterator>
43#include <numeric>
44#include <optional>
45#include <utility>
46
47using namespace mlir;
48
49namespace {
50static llvm::omp::ScheduleKind
51convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
52 if (!schedKind.has_value())
53 return llvm::omp::OMP_SCHEDULE_Default;
54 switch (schedKind.value()) {
55 case omp::ClauseScheduleKind::Static:
56 return llvm::omp::OMP_SCHEDULE_Static;
57 case omp::ClauseScheduleKind::Dynamic:
58 return llvm::omp::OMP_SCHEDULE_Dynamic;
59 case omp::ClauseScheduleKind::Guided:
60 return llvm::omp::OMP_SCHEDULE_Guided;
61 case omp::ClauseScheduleKind::Auto:
62 return llvm::omp::OMP_SCHEDULE_Auto;
63 case omp::ClauseScheduleKind::Runtime:
64 return llvm::omp::OMP_SCHEDULE_Runtime;
65 case omp::ClauseScheduleKind::Distribute:
66 return llvm::omp::OMP_SCHEDULE_Distribute;
67 }
68 llvm_unreachable("unhandled schedule clause argument");
69}
70
71/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
72/// insertion points for allocas.
73class OpenMPAllocaStackFrame
74 : public StateStackFrameBase<OpenMPAllocaStackFrame> {
75public:
77
78 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
79 : allocaInsertPoint(allocaIP) {}
80 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
81};
82
83/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
84/// collapsed canonical loop information corresponding to an \c omp.loop_nest
85/// operation.
86class OpenMPLoopInfoStackFrame
87 : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
88public:
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
90 llvm::CanonicalLoopInfo *loopInfo = nullptr;
91};
92
93/// Custom error class to signal translation errors that don't need reporting,
94/// since encountering them will have already triggered relevant error messages.
95///
96/// Its purpose is to serve as the glue between MLIR failures represented as
97/// \see LogicalResult instances and \see llvm::Error instances used to
98/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
99/// error of the first type is raised, a message is emitted directly (the \see
100/// LogicalResult itself does not hold any information). If we need to forward
101/// this error condition as an \see llvm::Error while avoiding triggering some
102/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
103/// class to just signal this situation has happened.
104///
105/// For example, this class should be used to trigger errors from within
106/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
107/// translation of their own regions. This unclutters the error log from
108/// redundant messages.
109class PreviouslyReportedError
110 : public llvm::ErrorInfo<PreviouslyReportedError> {
111public:
112 void log(raw_ostream &) const override {
113 // Do not log anything.
114 }
115
116 std::error_code convertToErrorCode() const override {
117 llvm_unreachable(
118 "PreviouslyReportedError doesn't support ECError conversion");
119 }
120
121 // Used by ErrorInfo::classID.
122 static char ID;
123};
124
125char PreviouslyReportedError::ID = 0;
126
127/*
128 * Custom class for processing linear clause for omp.wsloop
129 * and omp.simd. Linear clause translation requires setup,
130 * initialization, update, and finalization at varying
131 * basic blocks in the IR. This class helps maintain
132 * internal state to allow consistent translation in
133 * each of these stages.
134 */
135
136class LinearClauseProcessor {
137
138private:
139 SmallVector<llvm::Value *> linearPreconditionVars;
140 SmallVector<llvm::Value *> linearLoopBodyTemps;
141 SmallVector<llvm::Value *> linearOrigVal;
142 SmallVector<llvm::Value *> linearSteps;
143 SmallVector<llvm::Type *> linearVarTypes;
144 llvm::BasicBlock *linearFinalizationBB;
145 llvm::BasicBlock *linearExitBB;
146 llvm::BasicBlock *linearLastIterExitBB;
147
148public:
149 // Register type for the linear variables
150 void registerType(LLVM::ModuleTranslation &moduleTranslation,
151 mlir::Attribute &ty) {
152 linearVarTypes.push_back(moduleTranslation.convertType(
153 mlir::cast<mlir::TypeAttr>(ty).getValue()));
154 }
155
156 // Allocate space for linear variabes
157 void createLinearVar(llvm::IRBuilderBase &builder,
158 LLVM::ModuleTranslation &moduleTranslation,
159 llvm::Value *linearVar, int idx) {
160 linearPreconditionVars.push_back(
161 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
162 llvm::Value *linearLoopBodyTemp =
163 builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
164 linearOrigVal.push_back(linearVar);
165 linearLoopBodyTemps.push_back(linearLoopBodyTemp);
166 }
167
168 // Initialize linear step
169 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
170 mlir::Value &linearStep) {
171 linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
172 }
173
174 // Emit IR for initialization of linear variables
175 void initLinearVar(llvm::IRBuilderBase &builder,
176 LLVM::ModuleTranslation &moduleTranslation,
177 llvm::BasicBlock *loopPreHeader) {
178 builder.SetInsertPoint(loopPreHeader->getTerminator());
179 for (size_t index = 0; index < linearOrigVal.size(); index++) {
180 llvm::LoadInst *linearVarLoad =
181 builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
182 builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
183 }
184 }
185
186 // Emit IR for updating Linear variables
187 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
188 llvm::Value *loopInductionVar) {
189 builder.SetInsertPoint(loopBody->getTerminator());
190 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
191 llvm::Type *linearVarType = linearVarTypes[index];
192 llvm::Value *iv = loopInductionVar;
193 llvm::Value *step = linearSteps[index];
194
195 if (!iv->getType()->isIntegerTy())
196 llvm_unreachable("OpenMP loop induction variable must be an integer "
197 "type");
198
199 if (linearVarType->isIntegerTy()) {
200 // Integer path: normalize all arithmetic to linearVarType
201 iv = builder.CreateSExtOrTrunc(iv, linearVarType);
202 step = builder.CreateSExtOrTrunc(step, linearVarType);
203
204 llvm::LoadInst *linearVarStart =
205 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
206 llvm::Value *mulInst = builder.CreateMul(iv, step);
207 llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
208 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
209 } else if (linearVarType->isFloatingPointTy()) {
210 // Float path: perform multiply in integer, then convert to float
211 step = builder.CreateSExtOrTrunc(step, iv->getType());
212 llvm::Value *mulInst = builder.CreateMul(iv, step);
213
214 llvm::LoadInst *linearVarStart =
215 builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
216 llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
217 llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
218 builder.CreateStore(addInst, linearLoopBodyTemps[index]);
219 } else {
220 llvm_unreachable(
221 "Linear variable must be of integer or floating-point type");
222 }
223 }
224 }
225
226 // Linear variable finalization is conditional on the last logical iteration.
227 // Create BB splits to manage the same.
228 void splitLinearFiniBB(llvm::IRBuilderBase &builder,
229 llvm::BasicBlock *loopExit) {
230 linearFinalizationBB = loopExit->splitBasicBlock(
231 loopExit->getTerminator(), "omp_loop.linear_finalization");
232 linearExitBB = linearFinalizationBB->splitBasicBlock(
233 linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
234 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
235 linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
236 }
237
238 // Finalize the linear vars
239 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
240 finalizeLinearVar(llvm::IRBuilderBase &builder,
241 LLVM::ModuleTranslation &moduleTranslation,
242 llvm::Value *lastIter) {
243 // Emit condition to check whether last logical iteration is being executed
244 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
245 llvm::Value *loopLastIterLoad = builder.CreateLoad(
246 llvm::Type::getInt32Ty(builder.getContext()), lastIter);
247 llvm::Value *isLast =
248 builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
249 llvm::ConstantInt::get(
250 llvm::Type::getInt32Ty(builder.getContext()), 0));
251 // Store the linear variable values to original variables.
252 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
253 for (size_t index = 0; index < linearOrigVal.size(); index++) {
254 llvm::LoadInst *linearVarTemp =
255 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
256 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
257 }
258
259 // Create conditional branch such that the linear variable
260 // values are stored to original variables only at the
261 // last logical iteration
262 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
263 builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
264 linearFinalizationBB->getTerminator()->eraseFromParent();
265 // Emit barrier
266 builder.SetInsertPoint(linearExitBB->getTerminator());
267 return moduleTranslation.getOpenMPBuilder()->createBarrier(
268 builder.saveIP(), llvm::omp::OMPD_barrier);
269 }
270
271 // Emit stores for linear variables. Useful in case of SIMD
272 // construct.
273 void emitStoresForLinearVar(llvm::IRBuilderBase &builder) {
274 for (size_t index = 0; index < linearOrigVal.size(); index++) {
275 llvm::LoadInst *linearVarTemp =
276 builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
277 builder.CreateStore(linearVarTemp, linearOrigVal[index]);
278 }
279 }
280
281 // Rewrite all uses of the original variable in `BBName`
282 // with the linear variable in-place
283 void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
284 size_t varIndex) {
285 llvm::SmallVector<llvm::User *> users;
286 for (llvm::User *user : linearOrigVal[varIndex]->users())
287 users.push_back(user);
288 for (auto *user : users) {
289 if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
290 if (userInst->getParent()->getName().str().find(BBName) !=
291 std::string::npos)
292 user->replaceUsesOfWith(linearOrigVal[varIndex],
293 linearLoopBodyTemps[varIndex]);
294 }
295 }
296 }
297};
298
299} // namespace
300
301/// Looks up from the operation from and returns the PrivateClauseOp with
302/// name symbolName
303static omp::PrivateClauseOp findPrivatizer(Operation *from,
304 SymbolRefAttr symbolName) {
305 omp::PrivateClauseOp privatizer =
307 symbolName);
308 assert(privatizer && "privatizer not found in the symbol table");
309 return privatizer;
310}
311
312/// Check whether translation to LLVM IR for the given operation is currently
313/// supported. If not, descriptive diagnostics will be emitted to let users know
314/// this is a not-yet-implemented feature.
315///
316/// \returns success if no unimplemented features are needed to translate the
317/// given operation.
318static LogicalResult checkImplementationStatus(Operation &op) {
319 auto todo = [&op](StringRef clauseName) {
320 return op.emitError() << "not yet implemented: Unhandled clause "
321 << clauseName << " in " << op.getName()
322 << " operation";
323 };
324
325 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
326 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
327 result = todo("allocate");
328 };
329 auto checkBare = [&todo](auto op, LogicalResult &result) {
330 if (op.getBare())
331 result = todo("ompx_bare");
332 };
333 auto checkDepend = [&todo](auto op, LogicalResult &result) {
334 if (!op.getDependVars().empty() || op.getDependKinds())
335 result = todo("depend");
336 };
337 auto checkHint = [](auto op, LogicalResult &) {
338 if (op.getHint())
339 op.emitWarning("hint clause discarded");
340 };
341 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
342 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
343 op.getInReductionSyms())
344 result = todo("in_reduction");
345 };
346 auto checkNowait = [&todo](auto op, LogicalResult &result) {
347 if (op.getNowait())
348 result = todo("nowait");
349 };
350 auto checkOrder = [&todo](auto op, LogicalResult &result) {
351 if (op.getOrder() || op.getOrderMod())
352 result = todo("order");
353 };
354 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
355 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
356 result = todo("privatization");
357 };
358 auto checkReduction = [&todo](auto op, LogicalResult &result) {
359 if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
360 if (!op.getReductionVars().empty() || op.getReductionByref() ||
361 op.getReductionSyms())
362 result = todo("reduction");
363 if (op.getReductionMod() &&
364 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
365 result = todo("reduction with modifier");
366 };
367 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
368 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
369 op.getTaskReductionSyms())
370 result = todo("task_reduction");
371 };
372 auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
373 if (op.hasNumTeamsMultiDim())
374 result = todo("num_teams with multi-dimensional values");
375 };
376 auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
377 if (op.hasNumThreadsMultiDim())
378 result = todo("num_threads with multi-dimensional values");
379 };
380
381 auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
382 if (op.hasThreadLimitMultiDim())
383 result = todo("thread_limit with multi-dimensional values");
384 };
385
386 LogicalResult result = success();
388 .Case([&](omp::DistributeOp op) {
389 checkAllocate(op, result);
390 checkOrder(op, result);
391 })
392 .Case([&](omp::SectionsOp op) {
393 checkAllocate(op, result);
394 checkPrivate(op, result);
395 checkReduction(op, result);
396 })
397 .Case([&](omp::SingleOp op) {
398 checkAllocate(op, result);
399 checkPrivate(op, result);
400 })
401 .Case([&](omp::TeamsOp op) {
402 checkAllocate(op, result);
403 checkPrivate(op, result);
404 checkNumTeams(op, result);
405 checkThreadLimit(op, result);
406 })
407 .Case([&](omp::TaskOp op) {
408 checkAllocate(op, result);
409 checkInReduction(op, result);
410 })
411 .Case([&](omp::TaskgroupOp op) {
412 checkAllocate(op, result);
413 checkTaskReduction(op, result);
414 })
415 .Case([&](omp::TaskwaitOp op) {
416 checkDepend(op, result);
417 checkNowait(op, result);
418 })
419 .Case([&](omp::TaskloopContextOp op) {
420 checkAllocate(op, result);
421 checkInReduction(op, result);
422 checkReduction(op, result);
423 })
424 .Case([&](omp::WsloopOp op) {
425 checkAllocate(op, result);
426 checkOrder(op, result);
427 checkReduction(op, result);
428 })
429 .Case([&](omp::ParallelOp op) {
430 checkAllocate(op, result);
431 checkReduction(op, result);
432 checkNumThreads(op, result);
433 })
434 .Case([&](omp::SimdOp op) { checkReduction(op, result); })
435 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
436 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
437 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
438 [&](auto op) { checkDepend(op, result); })
439 .Case([&](omp::TargetUpdateOp op) { checkDepend(op, result); })
440 .Case([&](omp::TargetOp op) {
441 checkAllocate(op, result);
442 checkBare(op, result);
443 checkInReduction(op, result);
444 checkThreadLimit(op, result);
445 })
446 .Default([](Operation &) {
447 // Assume all clauses for an operation can be translated unless they are
448 // checked above.
449 });
450 return result;
451}
452
453static LogicalResult handleError(llvm::Error error, Operation &op) {
454 LogicalResult result = success();
455 if (error) {
456 llvm::handleAllErrors(
457 std::move(error),
458 [&](const PreviouslyReportedError &) { result = failure(); },
459 [&](const llvm::ErrorInfoBase &err) {
460 result = op.emitError(err.message());
461 });
462 }
463 return result;
464}
465
466template <typename T>
467static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
468 if (!result)
469 return handleError(result.takeError(), op);
470
471 return success();
472}
473
474/// Find the insertion point for allocas given the current insertion point for
475/// normal operations in the builder.
476static llvm::OpenMPIRBuilder::InsertPointTy
477findAllocaInsertPoint(llvm::IRBuilderBase &builder,
478 LLVM::ModuleTranslation &moduleTranslation) {
479 // If there is an alloca insertion point on stack, i.e. we are in a nested
480 // operation and a specific point was provided by some surrounding operation,
481 // use it.
482 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
483 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
484 [&](OpenMPAllocaStackFrame &frame) {
485 allocaInsertPoint = frame.allocaInsertPoint;
486 return WalkResult::interrupt();
487 });
488 // In cases with multiple levels of outlining, the tree walk might find an
489 // alloca insertion point that is inside the original function while the
490 // builder insertion point is inside the outlined function. We need to make
491 // sure that we do not use it in those cases.
492 if (walkResult.wasInterrupted() &&
493 allocaInsertPoint.getBlock()->getParent() ==
494 builder.GetInsertBlock()->getParent())
495 return allocaInsertPoint;
496
497 // Otherwise, insert to the entry block of the surrounding function.
498 // If the current IRBuilder InsertPoint is the function's entry, it cannot
499 // also be used for alloca insertion which would result in insertion order
500 // confusion. Create a new BasicBlock for the Builder and use the entry block
501 // for the allocs.
502 // TODO: Create a dedicated alloca BasicBlock at function creation such that
503 // we do not need to move the current InertPoint here.
504 if (builder.GetInsertBlock() ==
505 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
506 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
507 "Assuming end of basic block");
508 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
509 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
510 builder.GetInsertBlock()->getNextNode());
511 builder.CreateBr(entryBB);
512 builder.SetInsertPoint(entryBB);
513 }
514
515 llvm::BasicBlock &funcEntryBlock =
516 builder.GetInsertBlock()->getParent()->getEntryBlock();
517 return llvm::OpenMPIRBuilder::InsertPointTy(
518 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
519}
520
521/// Find the loop information structure for the loop nest being translated. It
522/// will return a `null` value unless called from the translation function for
523/// a loop wrapper operation after successfully translating its body.
524static llvm::CanonicalLoopInfo *
526 llvm::CanonicalLoopInfo *loopInfo = nullptr;
527 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
528 [&](OpenMPLoopInfoStackFrame &frame) {
529 loopInfo = frame.loopInfo;
530 return WalkResult::interrupt();
531 });
532 return loopInfo;
533}
534
535/// Converts the given region that appears within an OpenMP dialect operation to
536/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
537/// region, and a branch from any block with an successor-less OpenMP terminator
538/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
539/// of the continuation block if provided.
541 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
542 LLVM::ModuleTranslation &moduleTranslation,
543 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
544 bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
545
546 llvm::BasicBlock *continuationBlock =
547 splitBB(builder, true, "omp.region.cont");
548 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
549
550 llvm::LLVMContext &llvmContext = builder.getContext();
551 for (Block &bb : region) {
552 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
553 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
554 builder.GetInsertBlock()->getNextNode());
555 moduleTranslation.mapBlock(&bb, llvmBB);
556 }
557
558 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
559
560 // Terminators (namely YieldOp) may be forwarding values to the region that
561 // need to be available in the continuation block. Collect the types of these
562 // operands in preparation of creating PHI nodes. This is skipped for loop
563 // wrapper operations, for which we know in advance they have no terminators.
564 SmallVector<llvm::Type *> continuationBlockPHITypes;
565 unsigned numYields = 0;
566
567 if (!isLoopWrapper) {
568 bool operandsProcessed = false;
569 for (Block &bb : region.getBlocks()) {
570 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
571 if (!operandsProcessed) {
572 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
573 continuationBlockPHITypes.push_back(
574 moduleTranslation.convertType(yield->getOperand(i).getType()));
575 }
576 operandsProcessed = true;
577 } else {
578 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
579 "mismatching number of values yielded from the region");
580 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
581 llvm::Type *operandType =
582 moduleTranslation.convertType(yield->getOperand(i).getType());
583 (void)operandType;
584 assert(continuationBlockPHITypes[i] == operandType &&
585 "values of mismatching types yielded from the region");
586 }
587 }
588 numYields++;
589 }
590 }
591 }
592
593 // Insert PHI nodes in the continuation block for any values forwarded by the
594 // terminators in this region.
595 if (!continuationBlockPHITypes.empty())
596 assert(
597 continuationBlockPHIs &&
598 "expected continuation block PHIs if converted regions yield values");
599 if (continuationBlockPHIs) {
600 llvm::IRBuilderBase::InsertPointGuard guard(builder);
601 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
602 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
603 for (llvm::Type *ty : continuationBlockPHITypes)
604 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
605 }
606
607 // Convert blocks one by one in topological order to ensure
608 // defs are converted before uses.
610 for (Block *bb : blocks) {
611 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
612 // Retarget the branch of the entry block to the entry block of the
613 // converted region (regions are single-entry).
614 if (bb->isEntryBlock()) {
615 assert(sourceTerminator->getNumSuccessors() == 1 &&
616 "provided entry block has multiple successors");
617 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
618 "ContinuationBlock is not the successor of the entry block");
619 sourceTerminator->setSuccessor(0, llvmBB);
620 }
621
622 llvm::IRBuilderBase::InsertPointGuard guard(builder);
623 if (failed(
624 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
625 return llvm::make_error<PreviouslyReportedError>();
626
627 // Create a direct branch here for loop wrappers to prevent their lack of a
628 // terminator from causing a crash below.
629 if (isLoopWrapper) {
630 builder.CreateBr(continuationBlock);
631 continue;
632 }
633
634 // Special handling for `omp.yield` and `omp.terminator` (we may have more
635 // than one): they return the control to the parent OpenMP dialect operation
636 // so replace them with the branch to the continuation block. We handle this
637 // here to avoid relying inter-function communication through the
638 // ModuleTranslation class to set up the correct insertion point. This is
639 // also consistent with MLIR's idiom of handling special region terminators
640 // in the same code that handles the region-owning operation.
641 Operation *terminator = bb->getTerminator();
642 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
643 builder.CreateBr(continuationBlock);
644
645 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
646 (*continuationBlockPHIs)[i]->addIncoming(
647 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
648 }
649 }
650 // After all blocks have been traversed and values mapped, connect the PHI
651 // nodes to the results of preceding blocks.
652 LLVM::detail::connectPHINodes(region, moduleTranslation);
653
654 // Remove the blocks and values defined in this region from the mapping since
655 // they are not visible outside of this region. This allows the same region to
656 // be converted several times, that is cloned, without clashes, and slightly
657 // speeds up the lookups.
658 moduleTranslation.forgetMapping(region);
659
660 return continuationBlock;
661}
662
663/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
664static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
665 switch (kind) {
666 case omp::ClauseProcBindKind::Close:
667 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
668 case omp::ClauseProcBindKind::Master:
669 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
670 case omp::ClauseProcBindKind::Primary:
671 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
672 case omp::ClauseProcBindKind::Spread:
673 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
674 }
675 llvm_unreachable("Unknown ClauseProcBindKind kind");
676}
677
678/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
679static LogicalResult
680convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
681 LLVM::ModuleTranslation &moduleTranslation) {
682 auto maskedOp = cast<omp::MaskedOp>(opInst);
683 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
684
685 if (failed(checkImplementationStatus(opInst)))
686 return failure();
687
688 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
689 // MaskedOp has only one region associated with it.
690 auto &region = maskedOp.getRegion();
691 builder.restoreIP(codeGenIP);
692 return convertOmpOpRegions(region, "omp.masked.region", builder,
693 moduleTranslation)
694 .takeError();
695 };
696
697 // TODO: Perform finalization actions for variables. This has to be
698 // called for variables which have destructors/finalizers.
699 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
700
701 llvm::Value *filterVal = nullptr;
702 if (auto filterVar = maskedOp.getFilteredThreadId()) {
703 filterVal = moduleTranslation.lookupValue(filterVar);
704 } else {
705 llvm::LLVMContext &llvmContext = builder.getContext();
706 filterVal =
707 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
708 }
709 assert(filterVal != nullptr);
710 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
711 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
712 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
713 finiCB, filterVal);
714
715 if (failed(handleError(afterIP, opInst)))
716 return failure();
717
718 builder.restoreIP(*afterIP);
719 return success();
720}
721
722/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
723static LogicalResult
724convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
725 LLVM::ModuleTranslation &moduleTranslation) {
726 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
727 auto masterOp = cast<omp::MasterOp>(opInst);
728
729 if (failed(checkImplementationStatus(opInst)))
730 return failure();
731
732 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
733 // MasterOp has only one region associated with it.
734 auto &region = masterOp.getRegion();
735 builder.restoreIP(codeGenIP);
736 return convertOmpOpRegions(region, "omp.master.region", builder,
737 moduleTranslation)
738 .takeError();
739 };
740
741 // TODO: Perform finalization actions for variables. This has to be
742 // called for variables which have destructors/finalizers.
743 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
744
745 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
746 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
747 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
748 finiCB);
749
750 if (failed(handleError(afterIP, opInst)))
751 return failure();
752
753 builder.restoreIP(*afterIP);
754 return success();
755}
756
757/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
758static LogicalResult
759convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
760 LLVM::ModuleTranslation &moduleTranslation) {
761 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
762 auto criticalOp = cast<omp::CriticalOp>(opInst);
763
764 if (failed(checkImplementationStatus(opInst)))
765 return failure();
766
767 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
768 // CriticalOp has only one region associated with it.
769 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
770 builder.restoreIP(codeGenIP);
771 return convertOmpOpRegions(region, "omp.critical.region", builder,
772 moduleTranslation)
773 .takeError();
774 };
775
776 // TODO: Perform finalization actions for variables. This has to be
777 // called for variables which have destructors/finalizers.
778 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
779
780 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
781 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
782 llvm::Constant *hint = nullptr;
783
784 // If it has a name, it probably has a hint too.
785 if (criticalOp.getNameAttr()) {
786 // The verifiers in OpenMP Dialect guarentee that all the pointers are
787 // non-null
788 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
789 auto criticalDeclareOp =
791 symbolRef);
792 hint =
793 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
794 static_cast<int>(criticalDeclareOp.getHint()));
795 }
796 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
797 moduleTranslation.getOpenMPBuilder()->createCritical(
798 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
799
800 if (failed(handleError(afterIP, opInst)))
801 return failure();
802
803 builder.restoreIP(*afterIP);
804 return success();
805}
806
807/// A util to collect info needed to convert delayed privatizers from MLIR to
808/// LLVM.
810 template <typename OP>
812 : blockArgs(
813 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
814 mlirVars.reserve(blockArgs.size());
815 llvmVars.reserve(blockArgs.size());
816 collectPrivatizationDecls<OP>(op);
817
818 for (mlir::Value privateVar : op.getPrivateVars())
819 mlirVars.push_back(privateVar);
820 }
821
826
827private:
828 /// Populates `privatizations` with privatization declarations used for the
829 /// given op.
830 template <class OP>
831 void collectPrivatizationDecls(OP op) {
832 std::optional<ArrayAttr> attr = op.getPrivateSyms();
833 if (!attr)
834 return;
835
836 privatizers.reserve(privatizers.size() + attr->size());
837 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
838 privatizers.push_back(findPrivatizer(op, symbolRef));
839 }
840 }
841};
842
843/// Populates `reductions` with reduction declarations used in the given op.
844template <typename T>
845static void
848 std::optional<ArrayAttr> attr = op.getReductionSyms();
849 if (!attr)
850 return;
851
852 reductions.reserve(reductions.size() + op.getNumReductionVars());
853 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
854 reductions.push_back(
856 op, symbolRef));
857 }
858}
859
860/// Translates the blocks contained in the given region and appends them to at
861/// the current insertion point of `builder`. The operations of the entry block
862/// are appended to the current insertion block. If set, `continuationBlockArgs`
863/// is populated with translated values that correspond to the values
864/// omp.yield'ed from the region.
865static LogicalResult inlineConvertOmpRegions(
866 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
867 LLVM::ModuleTranslation &moduleTranslation,
868 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
869 if (region.empty())
870 return success();
871
872 // Special case for single-block regions that don't create additional blocks:
873 // insert operations without creating additional blocks.
874 if (region.hasOneBlock()) {
875 llvm::Instruction *potentialTerminator =
876 builder.GetInsertBlock()->empty() ? nullptr
877 : &builder.GetInsertBlock()->back();
878
879 if (potentialTerminator && potentialTerminator->isTerminator())
880 potentialTerminator->removeFromParent();
881 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
882
883 if (failed(moduleTranslation.convertBlock(
884 region.front(), /*ignoreArguments=*/true, builder)))
885 return failure();
886
887 // The continuation arguments are simply the translated terminator operands.
888 if (continuationBlockArgs)
889 llvm::append_range(
890 *continuationBlockArgs,
891 moduleTranslation.lookupValues(region.front().back().getOperands()));
892
893 // Drop the mapping that is no longer necessary so that the same region can
894 // be processed multiple times.
895 moduleTranslation.forgetMapping(region);
896
897 if (potentialTerminator && potentialTerminator->isTerminator()) {
898 llvm::BasicBlock *block = builder.GetInsertBlock();
899 if (block->empty()) {
900 // this can happen for really simple reduction init regions e.g.
901 // %0 = llvm.mlir.constant(0 : i32) : i32
902 // omp.yield(%0 : i32)
903 // because the llvm.mlir.constant (MLIR op) isn't converted into any
904 // llvm op
905 potentialTerminator->insertInto(block, block->begin());
906 } else {
907 potentialTerminator->insertAfter(&block->back());
908 }
909 }
910
911 return success();
912 }
913
915 llvm::Expected<llvm::BasicBlock *> continuationBlock =
916 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
917
918 if (failed(handleError(continuationBlock, *region.getParentOp())))
919 return failure();
920
921 if (continuationBlockArgs)
922 llvm::append_range(*continuationBlockArgs, phis);
923 builder.SetInsertPoint(*continuationBlock,
924 (*continuationBlock)->getFirstInsertionPt());
925 return success();
926}
927
928namespace {
929/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
930/// store lambdas with capture.
931using OwningReductionGen =
932 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
933 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
934 llvm::Value *&)>;
935using OwningAtomicReductionGen =
936 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
937 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
938 llvm::Value *)>;
939using OwningDataPtrPtrReductionGen =
940 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
941 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
942} // namespace
943
944/// Create an OpenMPIRBuilder-compatible reduction generator for the given
945/// reduction declaration. The generator uses `builder` but ignores its
946/// insertion point.
947static OwningReductionGen
948makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
949 LLVM::ModuleTranslation &moduleTranslation) {
950 // The lambda is mutable because we need access to non-const methods of decl
951 // (which aren't actually mutating it), and we must capture decl by-value to
952 // avoid the dangling reference after the parent function returns.
953 OwningReductionGen gen =
954 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
955 llvm::Value *lhs, llvm::Value *rhs,
956 llvm::Value *&result) mutable
957 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
958 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
959 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
960 builder.restoreIP(insertPoint);
962 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
963 "omp.reduction.nonatomic.body", builder,
964 moduleTranslation, &phis)))
965 return llvm::createStringError(
966 "failed to inline `combiner` region of `omp.declare_reduction`");
967 result = llvm::getSingleElement(phis);
968 return builder.saveIP();
969 };
970 return gen;
971}
972
973/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
974/// given reduction declaration. The generator uses `builder` but ignores its
975/// insertion point. Returns null if there is no atomic region available in the
976/// reduction declaration.
977static OwningAtomicReductionGen
978makeAtomicReductionGen(omp::DeclareReductionOp decl,
979 llvm::IRBuilderBase &builder,
980 LLVM::ModuleTranslation &moduleTranslation) {
981 if (decl.getAtomicReductionRegion().empty())
982 return OwningAtomicReductionGen();
983
984 // The lambda is mutable because we need access to non-const methods of decl
985 // (which aren't actually mutating it), and we must capture decl by-value to
986 // avoid the dangling reference after the parent function returns.
987 OwningAtomicReductionGen atomicGen =
988 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
989 llvm::Value *lhs, llvm::Value *rhs) mutable
990 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
991 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
992 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
993 builder.restoreIP(insertPoint);
995 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
996 "omp.reduction.atomic.body", builder,
997 moduleTranslation, &phis)))
998 return llvm::createStringError(
999 "failed to inline `atomic` region of `omp.declare_reduction`");
1000 assert(phis.empty());
1001 return builder.saveIP();
1002 };
1003 return atomicGen;
1004}
1005
1006/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
1007/// the given reduction declaration. The generator uses `builder` but ignores
1008/// its insertion point. Returns null if there is no `data_ptr_ptr` region
1009/// available in the reduction declaration.
1010static OwningDataPtrPtrReductionGen
1011makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
1012 LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
1013 if (!isByRef)
1014 return OwningDataPtrPtrReductionGen();
1015
1016 OwningDataPtrPtrReductionGen refDataPtrGen =
1017 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
1018 llvm::Value *byRefVal, llvm::Value *&result) mutable
1019 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1020 moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
1021 builder.restoreIP(insertPoint);
1023 if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
1024 "omp.data_ptr_ptr.body", builder,
1025 moduleTranslation, &phis)))
1026 return llvm::createStringError(
1027 "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
1028 result = llvm::getSingleElement(phis);
1029 return builder.saveIP();
1030 };
1031
1032 return refDataPtrGen;
1033}
1034
1035/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1036static LogicalResult
1037convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1038 LLVM::ModuleTranslation &moduleTranslation) {
1039 auto orderedOp = cast<omp::OrderedOp>(opInst);
1040
1041 if (failed(checkImplementationStatus(opInst)))
1042 return failure();
1043
1044 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1045 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1046 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1047 SmallVector<llvm::Value *> vecValues =
1048 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
1049
1050 size_t indexVecValues = 0;
1051 while (indexVecValues < vecValues.size()) {
1052 SmallVector<llvm::Value *> storeValues;
1053 storeValues.reserve(numLoops);
1054 for (unsigned i = 0; i < numLoops; i++) {
1055 storeValues.push_back(vecValues[indexVecValues]);
1056 indexVecValues++;
1057 }
1058 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1059 findAllocaInsertPoint(builder, moduleTranslation);
1060 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1061 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1062 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
1063 }
1064 return success();
1065}
1066
1067/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1068/// OpenMPIRBuilder.
1069static LogicalResult
1070convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1071 LLVM::ModuleTranslation &moduleTranslation) {
1072 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1073 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1074
1075 if (failed(checkImplementationStatus(opInst)))
1076 return failure();
1077
1078 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1079 // OrderedOp has only one region associated with it.
1080 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1081 builder.restoreIP(codeGenIP);
1082 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1083 moduleTranslation)
1084 .takeError();
1085 };
1086
1087 // TODO: Perform finalization actions for variables. This has to be
1088 // called for variables which have destructors/finalizers.
1089 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1090
1091 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1092 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1093 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1094 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
1095
1096 if (failed(handleError(afterIP, opInst)))
1097 return failure();
1098
1099 builder.restoreIP(*afterIP);
1100 return success();
1101}
1102
1103namespace {
1104/// Contains the arguments for an LLVM store operation
1105struct DeferredStore {
1106 DeferredStore(llvm::Value *value, llvm::Value *address)
1107 : value(value), address(address) {}
1108
1109 llvm::Value *value;
1110 llvm::Value *address;
1111};
1112} // namespace
1113
1114/// Allocate space for privatized reduction variables.
1115/// `deferredStores` contains information to create store operations which needs
1116/// to be inserted after all allocas
1117template <typename T>
1118static LogicalResult
1120 llvm::IRBuilderBase &builder,
1121 LLVM::ModuleTranslation &moduleTranslation,
1122 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1124 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1125 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1126 SmallVectorImpl<DeferredStore> &deferredStores,
1127 llvm::ArrayRef<bool> isByRefs) {
1128 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1129 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1130
1131 // delay creating stores until after all allocas
1132 deferredStores.reserve(loop.getNumReductionVars());
1133
1134 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1135 Region &allocRegion = reductionDecls[i].getAllocRegion();
1136 if (isByRefs[i]) {
1137 if (allocRegion.empty())
1138 continue;
1139
1141 if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
1142 builder, moduleTranslation, &phis)))
1143 return loop.emitError(
1144 "failed to inline `alloc` region of `omp.declare_reduction`");
1145
1146 assert(phis.size() == 1 && "expected one allocation to be yielded");
1147 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1148
1149 // Allocate reduction variable (which is a pointer to the real reduction
1150 // variable allocated in the inlined region)
1151 llvm::Value *var = builder.CreateAlloca(
1152 moduleTranslation.convertType(reductionDecls[i].getType()));
1153
1154 llvm::Type *ptrTy = builder.getPtrTy();
1155 llvm::Value *castVar =
1156 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1157 llvm::Value *castPhi =
1158 builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
1159
1160 deferredStores.emplace_back(castPhi, castVar);
1161
1162 privateReductionVariables[i] = castVar;
1163 moduleTranslation.mapValue(reductionArgs[i], castPhi);
1164 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1165 } else {
1166 assert(allocRegion.empty() &&
1167 "allocaction is implicit for by-val reduction");
1168 llvm::Value *var = builder.CreateAlloca(
1169 moduleTranslation.convertType(reductionDecls[i].getType()));
1170
1171 llvm::Type *ptrTy = builder.getPtrTy();
1172 llvm::Value *castVar =
1173 builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
1174
1175 moduleTranslation.mapValue(reductionArgs[i], castVar);
1176 privateReductionVariables[i] = castVar;
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1178 }
1179 }
1180
1181 return success();
1182}
1183
1184/// Map input arguments to reduction initialization region
1185template <typename T>
1186static void
1188 llvm::IRBuilderBase &builder,
1190 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1191 unsigned i) {
1192 // map input argument to the initialization region
1193 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1194 Region &initializerRegion = reduction.getInitializerRegion();
1195 Block &entry = initializerRegion.front();
1196
1197 mlir::Value mlirSource = loop.getReductionVars()[i];
1198 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1199 llvm::Value *origVal = llvmSource;
1200 // If a non-pointer value is expected, load the value from the source pointer.
1201 if (!isa<LLVM::LLVMPointerType>(
1202 reduction.getInitializerMoldArg().getType()) &&
1203 isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
1204 origVal =
1205 builder.CreateLoad(moduleTranslation.convertType(
1206 reduction.getInitializerMoldArg().getType()),
1207 llvmSource, "omp_orig");
1208 }
1209 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
1210
1211 if (entry.getNumArguments() > 1) {
1212 llvm::Value *allocation =
1213 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1214 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1215 }
1216}
1217
1218static void
1219setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1220 llvm::BasicBlock *block = nullptr) {
1221 if (block == nullptr)
1222 block = builder.GetInsertBlock();
1223
1224 if (!block->hasTerminator())
1225 builder.SetInsertPoint(block);
1226 else
1227 builder.SetInsertPoint(block->getTerminator());
1228}
1229
1230/// Inline reductions' `init` regions. This functions assumes that the
1231/// `builder`'s insertion point is where the user wants the `init` regions to be
1232/// inlined; i.e. it does not try to find a proper insertion location for the
1233/// `init` regions. It also leaves the `builder's insertions point in a state
1234/// where the user can continue the code-gen directly afterwards.
1235template <typename OP>
1236static LogicalResult
1237initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1238 llvm::IRBuilderBase &builder,
1239 LLVM::ModuleTranslation &moduleTranslation,
1240 llvm::BasicBlock *latestAllocaBlock,
1242 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1243 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1244 llvm::ArrayRef<bool> isByRef,
1245 SmallVectorImpl<DeferredStore> &deferredStores) {
1246 if (op.getNumReductionVars() == 0)
1247 return success();
1248
1249 llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
1250 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1251 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1252 builder.restoreIP(allocaIP);
1253 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1254
1255 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1256 if (isByRef[i]) {
1257 if (!reductionDecls[i].getAllocRegion().empty())
1258 continue;
1259
1260 // TODO: remove after all users of by-ref are updated to use the alloc
1261 // region: Allocate reduction variable (which is a pointer to the real
1262 // reduciton variable allocated in the inlined region)
1263 byRefVars[i] = builder.CreateAlloca(
1264 moduleTranslation.convertType(reductionDecls[i].getType()));
1265 }
1266 }
1267
1268 setInsertPointForPossiblyEmptyBlock(builder, initBlock);
1269
1270 // store result of the alloc region to the allocated pointer to the real
1271 // reduction variable
1272 for (auto [data, addr] : deferredStores)
1273 builder.CreateStore(data, addr);
1274
1275 // Before the loop, store the initial values of reductions into reduction
1276 // variables. Although this could be done after allocas, we don't want to mess
1277 // up with the alloca insertion point.
1278 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1280
1281 // map block argument to initializer region
1282 mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
1283 reductionVariableMap, i);
1284
1285 // TODO In some cases (specially on the GPU), the init regions may
1286 // contains stack alloctaions. If the region is inlined in a loop, this is
1287 // problematic. Instead of just inlining the region, handle allocations by
1288 // hoisting fixed length allocations to the function entry and using
1289 // stacksave and restore for variable length ones.
1290 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1291 "omp.reduction.neutral", builder,
1292 moduleTranslation, &phis)))
1293 return failure();
1294
1295 assert(phis.size() == 1 && "expected one value to be yielded from the "
1296 "reduction neutral element declaration region");
1297
1299
1300 if (isByRef[i]) {
1301 if (!reductionDecls[i].getAllocRegion().empty())
1302 // done in allocReductionVars
1303 continue;
1304
1305 // TODO: this path can be removed once all users of by-ref are updated to
1306 // use an alloc region
1307
1308 // Store the result of the inlined region to the allocated reduction var
1309 // ptr
1310 builder.CreateStore(phis[0], byRefVars[i]);
1311
1312 privateReductionVariables[i] = byRefVars[i];
1313 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1314 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1315 } else {
1316 // for by-ref case the store is inside of the reduction region
1317 builder.CreateStore(phis[0], privateReductionVariables[i]);
1318 // the rest was handled in allocByValReductionVars
1319 }
1320
1321 // forget the mapping for the initializer region because we might need a
1322 // different mapping if this reduction declaration is re-used for a
1323 // different variable
1324 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1325 }
1326
1327 return success();
1328}
1329
1330/// Collect reduction info
1331template <typename T>
1332static void collectReductionInfo(
1333 T loop, llvm::IRBuilderBase &builder,
1334 LLVM::ModuleTranslation &moduleTranslation,
1337 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1339 const ArrayRef<llvm::Value *> privateReductionVariables,
1341 ArrayRef<bool> isByRef) {
1342 unsigned numReductions = loop.getNumReductionVars();
1343
1344 for (unsigned i = 0; i < numReductions; ++i) {
1345 owningReductionGens.push_back(
1346 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1347 owningAtomicReductionGens.push_back(
1348 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
1350 reductionDecls[i], builder, moduleTranslation, isByRef[i]));
1351 }
1352
1353 // Collect the reduction information.
1354 reductionInfos.reserve(numReductions);
1355 for (unsigned i = 0; i < numReductions; ++i) {
1356 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1357 if (owningAtomicReductionGens[i])
1358 atomicGen = owningAtomicReductionGens[i];
1359 llvm::Value *variable =
1360 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1361 mlir::Type allocatedType;
1362 reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
1363 if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
1364 allocatedType = alloca.getElemType();
1366 }
1367
1369 });
1370
1371 reductionInfos.push_back(
1372 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1373 privateReductionVariables[i],
1374 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1376 /*ReductionGenClang=*/nullptr, atomicGen,
1378 allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
1379 reductionDecls[i].getByrefElementType()
1380 ? moduleTranslation.convertType(
1381 *reductionDecls[i].getByrefElementType())
1382 : nullptr});
1383 }
1384}
1385
1386/// handling of DeclareReductionOp's cleanup region
1387static LogicalResult
1389 llvm::ArrayRef<llvm::Value *> privateVariables,
1390 LLVM::ModuleTranslation &moduleTranslation,
1391 llvm::IRBuilderBase &builder, StringRef regionName,
1392 bool shouldLoadCleanupRegionArg = true) {
1393 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1394 if (cleanupRegion->empty())
1395 continue;
1396
1397 // map the argument to the cleanup region
1398 Block &entry = cleanupRegion->front();
1399
1400 llvm::Instruction *potentialTerminator =
1401 builder.GetInsertBlock()->empty() ? nullptr
1402 : &builder.GetInsertBlock()->back();
1403 if (potentialTerminator && potentialTerminator->isTerminator())
1404 builder.SetInsertPoint(potentialTerminator);
1405 llvm::Value *privateVarValue =
1406 shouldLoadCleanupRegionArg
1407 ? builder.CreateLoad(
1408 moduleTranslation.convertType(entry.getArgument(0).getType()),
1409 privateVariables[i])
1410 : privateVariables[i];
1411
1412 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1413
1414 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1415 moduleTranslation)))
1416 return failure();
1417
1418 // clear block argument mapping in case it needs to be re-created with a
1419 // different source for another use of the same reduction decl
1420 moduleTranslation.forgetMapping(*cleanupRegion);
1421 }
1422 return success();
1423}
1424
1425// TODO: not used by ParallelOp
1426template <class OP>
1427static LogicalResult createReductionsAndCleanup(
1428 OP op, llvm::IRBuilderBase &builder,
1429 LLVM::ModuleTranslation &moduleTranslation,
1430 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1432 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1433 bool isNowait = false, bool isTeamsReduction = false) {
1434 // Process the reductions if required.
1435 if (op.getNumReductionVars() == 0)
1436 return success();
1437
1439 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1440 SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
1442
1443 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1444
1445 // Create the reduction generators. We need to own them here because
1446 // ReductionInfo only accepts references to the generators.
1447 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1448 owningReductionGens, owningAtomicReductionGens,
1449 owningReductionGenRefDataPtrGens,
1450 privateReductionVariables, reductionInfos, isByRef);
1451
1452 // The call to createReductions below expects the block to have a
1453 // terminator. Create an unreachable instruction to serve as terminator
1454 // and remove it later.
1455 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1456 builder.SetInsertPoint(tempTerminator);
1457 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1458 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1459 isByRef, isNowait, isTeamsReduction);
1460
1461 if (failed(handleError(contInsertPoint, *op)))
1462 return failure();
1463
1464 if (!contInsertPoint->getBlock())
1465 return op->emitOpError() << "failed to convert reductions";
1466
1467 llvm::OpenMPIRBuilder::InsertPointTy afterIP = *contInsertPoint;
1468 if (!isTeamsReduction) {
1469 llvm::OpenMPIRBuilder::InsertPointOrErrorTy barrierIP =
1470 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1471
1472 if (failed(handleError(barrierIP, *op)))
1473 return failure();
1474 afterIP = *barrierIP;
1475 }
1476
1477 tempTerminator->eraseFromParent();
1478 builder.restoreIP(afterIP);
1479
1480 // after the construct, deallocate private reduction variables
1481 SmallVector<Region *> reductionRegions;
1482 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1483 [](omp::DeclareReductionOp reductionDecl) {
1484 return &reductionDecl.getCleanupRegion();
1485 });
1486 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1487 moduleTranslation, builder,
1488 "omp.reduction.cleanup");
1489 return success();
1490}
1491
1492static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1493 if (!attr)
1494 return {};
1495 return *attr;
1496}
1497
1498// TODO: not used by omp.parallel
1499template <typename OP>
1501 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1502 LLVM::ModuleTranslation &moduleTranslation,
1503 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1505 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1506 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1507 llvm::ArrayRef<bool> isByRef) {
1508 if (op.getNumReductionVars() == 0)
1509 return success();
1510
1511 SmallVector<DeferredStore> deferredStores;
1512
1513 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1514 allocaIP, reductionDecls,
1515 privateReductionVariables, reductionVariableMap,
1516 deferredStores, isByRef)))
1517 return failure();
1518
1519 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1520 allocaIP.getBlock(), reductionDecls,
1521 privateReductionVariables, reductionVariableMap,
1522 isByRef, deferredStores);
1523}
1524
1525/// Return the llvm::Value * corresponding to the `privateVar` that
1526/// is being privatized. It isn't always as simple as looking up
1527/// moduleTranslation with privateVar. For instance, in case of
1528/// an allocatable, the descriptor for the allocatable is privatized.
1529/// This descriptor is mapped using an MapInfoOp. So, this function
1530/// will return a pointer to the llvm::Value corresponding to the
1531/// block argument for the mapped descriptor.
1532static llvm::Value *
1533findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1534 LLVM::ModuleTranslation &moduleTranslation,
1535 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1536 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1537 return moduleTranslation.lookupValue(privateVar);
1538
1539 Value blockArg = (*mappedPrivateVars)[privateVar];
1540 Type privVarType = privateVar.getType();
1541 Type blockArgType = blockArg.getType();
1542 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1543 "A block argument corresponding to a mapped var should have "
1544 "!llvm.ptr type");
1545
1546 if (privVarType == blockArgType)
1547 return moduleTranslation.lookupValue(blockArg);
1548
1549 // This typically happens when the privatized type is lowered from
1550 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1551 // struct/pair is passed by value. But, mapped values are passed only as
1552 // pointers, so before we privatize, we must load the pointer.
1553 if (!isa<LLVM::LLVMPointerType>(privVarType))
1554 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1555 moduleTranslation.lookupValue(blockArg));
1556
1557 return moduleTranslation.lookupValue(privateVar);
1558}
1559
1560/// Initialize a single (first)private variable. You probably want to use
1561/// allocateAndInitPrivateVars instead of this.
1562/// This returns the private variable which has been initialized. This
1563/// variable should be mapped before constructing the body of the Op.
1565initPrivateVar(llvm::IRBuilderBase &builder,
1566 LLVM::ModuleTranslation &moduleTranslation,
1567 omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar,
1568 BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
1569 llvm::BasicBlock *privInitBlock,
1570 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1571 Region &initRegion = privDecl.getInitRegion();
1572 if (initRegion.empty())
1573 return llvmPrivateVar;
1574
1575 assert(nonPrivateVar);
1576 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1577 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1578
1579 // in-place convert the private initialization region
1581 if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
1582 moduleTranslation, &phis)))
1583 return llvm::createStringError(
1584 "failed to inline `init` region of `omp.private`");
1585
1586 assert(phis.size() == 1 && "expected one allocation to be yielded");
1587
1588 // clear init region block argument mapping in case it needs to be
1589 // re-created with a different source for another use of the same
1590 // reduction decl
1591 moduleTranslation.forgetMapping(initRegion);
1592
1593 // Prefer the value yielded from the init region to the allocated private
1594 // variable in case the region is operating on arguments by-value (e.g.
1595 // Fortran character boxes).
1596 return phis[0];
1597}
1598
1599/// Version of initPrivateVar which looks up the nonPrivateVar from mlirPrivVar.
1601 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1602 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1603 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1604 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1605 return initPrivateVar(
1606 builder, moduleTranslation, privDecl,
1607 findAssociatedValue(mlirPrivVar, builder, moduleTranslation,
1608 mappedPrivateVars),
1609 blockArg, llvmPrivateVar, privInitBlock, mappedPrivateVars);
1610}
1611
1612static llvm::Error
1613initPrivateVars(llvm::IRBuilderBase &builder,
1614 LLVM::ModuleTranslation &moduleTranslation,
1615 PrivateVarsInfo &privateVarsInfo,
1616 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1617 if (privateVarsInfo.blockArgs.empty())
1618 return llvm::Error::success();
1619
1620 llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
1621 setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
1622
1623 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1624 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1625 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1626 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1628 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1629 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1630
1631 if (!privVarOrErr)
1632 return privVarOrErr.takeError();
1633
1634 llvmPrivateVar = privVarOrErr.get();
1635 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1636
1638 }
1639
1640 return llvm::Error::success();
1641}
1642
1643/// Allocate and initialize delayed private variables. Returns the basic block
1644/// which comes after all of these allocations. llvm::Value * for each of these
1645/// private variables are populated in llvmPrivateVars.
1647allocatePrivateVars(llvm::IRBuilderBase &builder,
1648 LLVM::ModuleTranslation &moduleTranslation,
1649 PrivateVarsInfo &privateVarsInfo,
1650 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1651 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1652 // Allocate private vars
1653 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1654 splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1655 allocaTerminator->getIterator()),
1656 true, allocaTerminator->getStableDebugLoc(),
1657 "omp.region.after_alloca");
1658
1659 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1660 // Update the allocaTerminator since the alloca block was split above.
1661 allocaTerminator = allocaIP.getBlock()->getTerminator();
1662 builder.SetInsertPoint(allocaTerminator);
1663 // The new terminator is an uncondition branch created by the splitBB above.
1664 assert(allocaTerminator->getNumSuccessors() == 1 &&
1665 "This is an unconditional branch created by splitBB");
1666
1667 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1668 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
1669
1670 unsigned int allocaAS =
1671 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1672 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1673 ->getDataLayout()
1674 .getProgramAddressSpace();
1675
1676 for (auto [privDecl, mlirPrivVar, blockArg] :
1677 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1678 privateVarsInfo.blockArgs)) {
1679 llvm::Type *llvmAllocType =
1680 moduleTranslation.convertType(privDecl.getType());
1681 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1682 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1683 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1684 if (allocaAS != defaultAS)
1685 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1686 builder.getPtrTy(defaultAS));
1687
1688 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1689 }
1690
1691 return afterAllocas;
1692}
1693
1694/// This can't always be determined statically, but when we can, it is good to
1695/// avoid generating compiler-added barriers which will deadlock the program.
1697 for (mlir::Operation *parent = op->getParentOp(); parent != nullptr;
1698 parent = parent->getParentOp()) {
1699 if (mlir::isa<omp::SingleOp, omp::CriticalOp>(parent))
1700 return true;
1701
1702 // e.g.
1703 // omp.single {
1704 // omp.parallel {
1705 // op
1706 // }
1707 // }
1708 if (mlir::isa<omp::ParallelOp>(parent))
1709 return false;
1710 }
1711 return false;
1712}
1713
1714static LogicalResult copyFirstPrivateVars(
1715 mlir::Operation *op, llvm::IRBuilderBase &builder,
1716 LLVM::ModuleTranslation &moduleTranslation,
1718 ArrayRef<llvm::Value *> llvmPrivateVars,
1719 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1720 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1721 // Apply copy region for firstprivate.
1722 bool needsFirstprivate =
1723 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1724 return privOp.getDataSharingType() ==
1725 omp::DataSharingClauseType::FirstPrivate;
1726 });
1727
1728 if (!needsFirstprivate)
1729 return success();
1730
1731 llvm::BasicBlock *copyBlock =
1732 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
1733 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
1734
1735 for (auto [decl, moldVar, llvmVar] :
1736 llvm::zip_equal(privateDecls, moldVars, llvmPrivateVars)) {
1737 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1738 continue;
1739
1740 // copyRegion implements `lhs = rhs`
1741 Region &copyRegion = decl.getCopyRegion();
1742
1743 moduleTranslation.mapValue(decl.getCopyMoldArg(), moldVar);
1744
1745 // map copyRegion lhs arg
1746 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1747
1748 // in-place convert copy region
1749 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1750 moduleTranslation)))
1751 return decl.emitError("failed to inline `copy` region of `omp.private`");
1752
1754
1755 // ignore unused value yielded from copy region
1756
1757 // clear copy region block argument mapping in case it needs to be
1758 // re-created with different sources for reuse of the same reduction
1759 // decl
1760 moduleTranslation.forgetMapping(copyRegion);
1761 }
1762
1763 if (insertBarrier && !opIsInSingleThread(op)) {
1764 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1765 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1766 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1767 if (failed(handleError(res, *op)))
1768 return failure();
1769 }
1770
1771 return success();
1772}
1773
1774static LogicalResult copyFirstPrivateVars(
1775 mlir::Operation *op, llvm::IRBuilderBase &builder,
1776 LLVM::ModuleTranslation &moduleTranslation,
1777 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1778 ArrayRef<llvm::Value *> llvmPrivateVars,
1779 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1780 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1781 llvm::SmallVector<llvm::Value *> moldVars(mlirPrivateVars.size());
1782 llvm::transform(mlirPrivateVars, moldVars.begin(), [&](mlir::Value mlirVar) {
1783 // map copyRegion rhs arg
1784 llvm::Value *moldVar = findAssociatedValue(
1785 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1786 assert(moldVar);
1787 return moldVar;
1788 });
1789 return copyFirstPrivateVars(op, builder, moduleTranslation, moldVars,
1790 llvmPrivateVars, privateDecls, insertBarrier,
1791 mappedPrivateVars);
1792}
1793
1794static LogicalResult
1795cleanupPrivateVars(llvm::IRBuilderBase &builder,
1796 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1797 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1799 // private variable deallocation
1800 SmallVector<Region *> privateCleanupRegions;
1801 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1802 [](omp::PrivateClauseOp privatizer) {
1803 return &privatizer.getDeallocRegion();
1804 });
1805
1806 if (failed(inlineOmpRegionCleanup(
1807 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1808 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1809 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1810 "`omp.private` op in");
1811
1812 return success();
1813}
1814
1815/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1817 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1818 // be visible and not inside of function calls. This is enforced by the
1819 // verifier.
1820 return op
1821 ->walk([](Operation *child) {
1822 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1823 return WalkResult::interrupt();
1824 return WalkResult::advance();
1825 })
1826 .wasInterrupted();
1827}
1828
1829static LogicalResult
1830convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1831 LLVM::ModuleTranslation &moduleTranslation) {
1832 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1833 using StorableBodyGenCallbackTy =
1834 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1835
1836 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1837
1838 if (failed(checkImplementationStatus(opInst)))
1839 return failure();
1840
1841 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1842 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1843
1845 collectReductionDecls(sectionsOp, reductionDecls);
1846 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1847 findAllocaInsertPoint(builder, moduleTranslation);
1848
1849 SmallVector<llvm::Value *> privateReductionVariables(
1850 sectionsOp.getNumReductionVars());
1851 DenseMap<Value, llvm::Value *> reductionVariableMap;
1852
1853 MutableArrayRef<BlockArgument> reductionArgs =
1854 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1855
1857 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1858 reductionDecls, privateReductionVariables, reductionVariableMap,
1859 isByRef)))
1860 return failure();
1861
1863
1864 for (Operation &op : *sectionsOp.getRegion().begin()) {
1865 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1866 if (!sectionOp) // omp.terminator
1867 continue;
1868
1869 Region &region = sectionOp.getRegion();
1870 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1871 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1872 builder.restoreIP(codeGenIP);
1873
1874 // map the omp.section reduction block argument to the omp.sections block
1875 // arguments
1876 // TODO: this assumes that the only block arguments are reduction
1877 // variables
1878 assert(region.getNumArguments() ==
1879 sectionsOp.getRegion().getNumArguments());
1880 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1881 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1882 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1883 assert(llvmVal);
1884 moduleTranslation.mapValue(sectionArg, llvmVal);
1885 }
1886
1887 return convertOmpOpRegions(region, "omp.section.region", builder,
1888 moduleTranslation)
1889 .takeError();
1890 };
1891 sectionCBs.push_back(sectionCB);
1892 }
1893
1894 // No sections within omp.sections operation - skip generation. This situation
1895 // is only possible if there is only a terminator operation inside the
1896 // sections operation
1897 if (sectionCBs.empty())
1898 return success();
1899
1900 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1901
1902 // TODO: Perform appropriate actions according to the data-sharing
1903 // attribute (shared, private, firstprivate, ...) of variables.
1904 // Currently defaults to shared.
1905 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1906 llvm::Value &vPtr, llvm::Value *&replacementValue)
1907 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1908 replacementValue = &vPtr;
1909 return codeGenIP;
1910 };
1911
1912 // TODO: Perform finalization actions for variables. This has to be
1913 // called for variables which have destructors/finalizers.
1914 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1915
1916 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1917 bool isCancellable = constructIsCancellable(sectionsOp);
1918 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1919 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1920 moduleTranslation.getOpenMPBuilder()->createSections(
1921 ompLoc, allocaIP, sectionCBs, privCB, finiCB, isCancellable,
1922 sectionsOp.getNowait());
1923
1924 if (failed(handleError(afterIP, opInst)))
1925 return failure();
1926
1927 builder.restoreIP(*afterIP);
1928
1929 // Process the reductions if required.
1931 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1932 privateReductionVariables, isByRef, sectionsOp.getNowait());
1933}
1934
1935/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1936static LogicalResult
1937convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1938 LLVM::ModuleTranslation &moduleTranslation) {
1939 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1940 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1941
1942 if (failed(checkImplementationStatus(*singleOp)))
1943 return failure();
1944
1945 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1946 builder.restoreIP(codegenIP);
1947 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1948 builder, moduleTranslation)
1949 .takeError();
1950 };
1951 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1952
1953 // Handle copyprivate
1954 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1955 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1958 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1959 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1961 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1962 llvmCPFuncs.push_back(
1963 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1964 }
1965
1966 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1967 moduleTranslation.getOpenMPBuilder()->createSingle(
1968 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1969 llvmCPFuncs);
1970
1971 if (failed(handleError(afterIP, *singleOp)))
1972 return failure();
1973
1974 builder.restoreIP(*afterIP);
1975 return success();
1976}
1977
1978static omp::DistributeOp
1980 // Early return if we found more than one distribute op or if we can't find
1981 // any distribute op in the teams region.
1982 omp::DistributeOp distOp;
1983 WalkResult walk = teamsOp.getRegion().walk([&](omp::DistributeOp op) {
1984 if (distOp)
1985 return WalkResult::interrupt();
1986 distOp = op;
1987 return WalkResult::skip();
1988 });
1989 if (walk.wasInterrupted() || !distOp)
1990 return {};
1991
1992 auto iface =
1993 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1994 // Check that all uses of the reduction block arg has the same distribute op
1995 // parent.
1997 for (auto ra : iface.getReductionBlockArgs())
1998 for (auto &use : ra.getUses()) {
1999 auto *useOp = use.getOwner();
2000 // Ignore debug uses.
2001 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
2002 debugUses.push_back(useOp);
2003 continue;
2004 }
2005 if (!distOp->isProperAncestor(useOp))
2006 return {};
2007 }
2008
2009 // If we are going to use distribute reduction then remove any debug uses of
2010 // the reduction parameters in teamsOp. Otherwise they will be left without
2011 // any mapped value in moduleTranslation and will eventually error out.
2012 for (auto *use : debugUses)
2013 use->erase();
2014 return distOp;
2015}
2016
2017// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
2018static LogicalResult
2019convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
2020 LLVM::ModuleTranslation &moduleTranslation) {
2021 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2022 if (failed(checkImplementationStatus(*op)))
2023 return failure();
2024
2025 DenseMap<Value, llvm::Value *> reductionVariableMap;
2026 unsigned numReductionVars = op.getNumReductionVars();
2028 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
2029 llvm::ArrayRef<bool> isByRef;
2030 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2031 findAllocaInsertPoint(builder, moduleTranslation);
2032
2033 // Only do teams reduction if there is no distribute op that captures the
2034 // reduction instead.
2035 bool doTeamsReduction = !getDistributeCapturingTeamsReduction(op);
2036 if (doTeamsReduction) {
2037 isByRef = getIsByRef(op.getReductionByref());
2038
2039 assert(isByRef.size() == op.getNumReductionVars());
2040
2041 MutableArrayRef<BlockArgument> reductionArgs =
2042 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
2043
2044 collectReductionDecls(op, reductionDecls);
2045
2047 op, reductionArgs, builder, moduleTranslation, allocaIP,
2048 reductionDecls, privateReductionVariables, reductionVariableMap,
2049 isByRef)))
2050 return failure();
2051 }
2052
2053 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2055 moduleTranslation, allocaIP);
2056 builder.restoreIP(codegenIP);
2057 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
2058 moduleTranslation)
2059 .takeError();
2060 };
2061
2062 llvm::Value *numTeamsLower = nullptr;
2063 if (Value numTeamsLowerVar = op.getNumTeamsLower())
2064 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
2065
2066 llvm::Value *numTeamsUpper = nullptr;
2067 if (!op.getNumTeamsUpperVars().empty())
2068 numTeamsUpper = moduleTranslation.lookupValue(op.getNumTeams(0));
2069
2070 llvm::Value *threadLimit = nullptr;
2071 if (!op.getThreadLimitVars().empty())
2072 threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
2073
2074 llvm::Value *ifExpr = nullptr;
2075 if (Value ifVar = op.getIfExpr())
2076 ifExpr = moduleTranslation.lookupValue(ifVar);
2077
2078 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2079 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2080 moduleTranslation.getOpenMPBuilder()->createTeams(
2081 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2082
2083 if (failed(handleError(afterIP, *op)))
2084 return failure();
2085
2086 builder.restoreIP(*afterIP);
2087 if (doTeamsReduction) {
2088 // Process the reductions if required.
2090 op, builder, moduleTranslation, allocaIP, reductionDecls,
2091 privateReductionVariables, isByRef,
2092 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2093 }
2094 return success();
2095}
2096
2097static llvm::omp::RTLDependenceKindTy
2098convertDependKind(mlir::omp::ClauseTaskDepend kind) {
2099 switch (kind) {
2100 case mlir::omp::ClauseTaskDepend::taskdependin:
2101 return llvm::omp::RTLDependenceKindTy::DepIn;
2102 // The OpenMP runtime requires that the codegen for 'depend' clause for
2103 // 'out' dependency kind must be the same as codegen for 'depend' clause
2104 // with 'inout' dependency.
2105 case mlir::omp::ClauseTaskDepend::taskdependout:
2106 case mlir::omp::ClauseTaskDepend::taskdependinout:
2107 return llvm::omp::RTLDependenceKindTy::DepInOut;
2108 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2109 return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2110 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2111 return llvm::omp::RTLDependenceKindTy::DepInOutSet;
2112 }
2113 llvm_unreachable("unhandled depend kind");
2114}
2115
2117 std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2118 LLVM::ModuleTranslation &moduleTranslation,
2120 if (dependVars.empty())
2121 return;
2122 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2123 auto kind =
2124 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue();
2125 llvm::omp::RTLDependenceKindTy type = convertDependKind(kind);
2126 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2127 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2128 dds.emplace_back(dd);
2129 }
2130}
2131
2132/// Shared implementation of a callback which adds a termiator for the new block
2133/// created for the branch taken when an openmp construct is cancelled. The
2134/// terminator is saved in \p cancelTerminators. This callback is invoked only
2135/// if there is cancellation inside of the taskgroup body.
2136/// The terminator will need to be fixed to branch to the correct block to
2137/// cleanup the construct.
2139 SmallVectorImpl<llvm::UncondBrInst *> &cancelTerminators,
2140 llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder,
2141 mlir::Operation *op, llvm::omp::Directive cancelDirective) {
2142 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2143 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2144
2145 // ip is currently in the block branched to if cancellation occurred.
2146 // We need to create a branch to terminate that block.
2147 llvmBuilder.restoreIP(ip);
2148
2149 // We must still clean up the construct after cancelling it, so we need to
2150 // branch to the block that finalizes the taskgroup.
2151 // That block has not been created yet so use this block as a dummy for now
2152 // and fix this after creating the operation.
2153 cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2154 return llvm::Error::success();
2155 };
2156 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2157 // created in case the body contains omp.cancel (which will then expect to be
2158 // able to find this cleanup callback).
2159 ompBuilder.pushFinalizationCB(
2160 {finiCB, cancelDirective, constructIsCancellable(op)});
2161}
2162
2163/// If we cancelled the construct, we should branch to the finalization block of
2164/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2165/// is immediately before the continuation block. Now this finalization has
2166/// been created we can fix the branch.
2167static void
2169 llvm::OpenMPIRBuilder &ompBuilder,
2170 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2171 ompBuilder.popFinalizationCB();
2172 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2173 for (llvm::UncondBrInst *cancelBranch : cancelTerminators)
2174 cancelBranch->setSuccessor(constructFini);
2175}
2176
2177namespace {
2178/// TaskContextStructManager takes care of creating and freeing a structure
2179/// containing information needed by the task body to execute.
2180class TaskContextStructManager {
2181public:
2182 TaskContextStructManager(llvm::IRBuilderBase &builder,
2183 LLVM::ModuleTranslation &moduleTranslation,
2184 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2185 : builder{builder}, moduleTranslation{moduleTranslation},
2186 privateDecls{privateDecls} {}
2187
2188 /// Creates a heap allocated struct containing space for each private
2189 /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
2190 /// the structure should all have the same order (although privateDecls which
2191 /// do not read from the mold argument are skipped).
2192 void generateTaskContextStruct();
2193
2194 /// Create GEPs to access each member of the structure representing a private
2195 /// variable, adding them to llvmPrivateVars. Null values are added where
2196 /// private decls were skipped so that the ordering continues to match the
2197 /// private decls.
2198 void createGEPsToPrivateVars();
2199
2200 /// Given the address of the structure, return a GEP for each private variable
2201 /// in the structure. Null values are added where private decls were skipped
2202 /// so that the ordering continues to match the private decls.
2203 /// Must be called after generateTaskContextStruct().
2204 SmallVector<llvm::Value *>
2205 createGEPsToPrivateVars(llvm::Value *altStructPtr) const;
2206
2207 /// De-allocate the task context structure.
2208 void freeStructPtr();
2209
2210 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2211 return llvmPrivateVarGEPs;
2212 }
2213
2214 llvm::Value *getStructPtr() { return structPtr; }
2215
2216private:
2217 llvm::IRBuilderBase &builder;
2218 LLVM::ModuleTranslation &moduleTranslation;
2219 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2220
2221 /// The type of each member of the structure, in order.
2222 SmallVector<llvm::Type *> privateVarTypes;
2223
2224 /// LLVM values for each private variable, or null if that private variable is
2225 /// not included in the task context structure
2226 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2227
2228 /// A pointer to the structure containing context for this task.
2229 llvm::Value *structPtr = nullptr;
2230 /// The type of the structure
2231 llvm::Type *structTy = nullptr;
2232};
2233
2234/// IteratorInfo extracts and prepares loop bounds information from an
2235/// mlir::omp::IteratorOp for lowering to LLVM IR.
2236///
2237/// It computes the per-dimension trip counts and the total linearized trip
2238/// count, casted to i64. These are used to build a canonical loop and to
2239/// reconstruct the physical induction variables inside the loop body.
2240class IteratorInfo {
2241private:
2242 llvm::SmallVector<llvm::Value *> lowerBounds;
2243 llvm::SmallVector<llvm::Value *> upperBounds;
2244 llvm::SmallVector<llvm::Value *> steps;
2245 llvm::SmallVector<llvm::Value *> trips;
2246 unsigned dims;
2247 llvm::Value *totalTrips;
2248
2249 llvm::Value *lookUpAsI64(mlir::Value val, const LLVM::ModuleTranslation &mt,
2250 llvm::IRBuilderBase &builder) {
2251 llvm::Value *v = mt.lookupValue(val);
2252 if (!v)
2253 return nullptr;
2254 if (v->getType()->isIntegerTy(64))
2255 return v;
2256 if (v->getType()->isIntegerTy())
2257 return builder.CreateSExtOrTrunc(v, builder.getInt64Ty());
2258 return nullptr;
2259 }
2260
2261public:
2262 IteratorInfo(mlir::omp::IteratorOp itersOp,
2263 mlir::LLVM::ModuleTranslation &moduleTranslation,
2264 llvm::IRBuilderBase &builder) {
2265 dims = itersOp.getLoopLowerBounds().size();
2266 lowerBounds.resize(dims);
2267 upperBounds.resize(dims);
2268 steps.resize(dims);
2269 trips.resize(dims);
2270
2271 for (unsigned d = 0; d < dims; ++d) {
2272 llvm::Value *lb = lookUpAsI64(itersOp.getLoopLowerBounds()[d],
2273 moduleTranslation, builder);
2274 llvm::Value *ub = lookUpAsI64(itersOp.getLoopUpperBounds()[d],
2275 moduleTranslation, builder);
2276 llvm::Value *st =
2277 lookUpAsI64(itersOp.getLoopSteps()[d], moduleTranslation, builder);
2278 assert(lb && ub && st &&
2279 "Expect lowerBounds, upperBounds, and steps in IteratorOp");
2280 assert((!llvm::isa<llvm::ConstantInt>(st) ||
2281 !llvm::cast<llvm::ConstantInt>(st)->isZero()) &&
2282 "Expect non-zero step in IteratorOp");
2283
2284 lowerBounds[d] = lb;
2285 upperBounds[d] = ub;
2286 steps[d] = st;
2287
2288 // trips = ((ub - lb) / step) + 1 (inclusive ub, assume positive step)
2289 llvm::Value *diff = builder.CreateSub(ub, lb);
2290 llvm::Value *div = builder.CreateSDiv(diff, st);
2291 trips[d] = builder.CreateAdd(
2292 div, llvm::ConstantInt::get(builder.getInt64Ty(), 1));
2293 }
2294
2295 totalTrips = llvm::ConstantInt::get(builder.getInt64Ty(), 1);
2296 for (unsigned d = 0; d < dims; ++d)
2297 totalTrips = builder.CreateMul(totalTrips, trips[d]);
2298 }
2299
2300 unsigned getDims() const { return dims; }
2301 llvm::ArrayRef<llvm::Value *> getLowerBounds() const { return lowerBounds; }
2302 llvm::ArrayRef<llvm::Value *> getUpperBounds() const { return upperBounds; }
2303 llvm::ArrayRef<llvm::Value *> getSteps() const { return steps; }
2304 llvm::ArrayRef<llvm::Value *> getTrips() const { return trips; }
2305 llvm::Value *getTotalTrips() const { return totalTrips; }
2306};
2307
2308} // namespace
2309
2310void TaskContextStructManager::generateTaskContextStruct() {
2311 if (privateDecls.empty())
2312 return;
2313 privateVarTypes.reserve(privateDecls.size());
2314
2315 for (omp::PrivateClauseOp &privOp : privateDecls) {
2316 // Skip private variables which can safely be allocated and initialised
2317 // inside of the task
2318 if (!privOp.readsFromMold())
2319 continue;
2320 Type mlirType = privOp.getType();
2321 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2322 }
2323
2324 if (privateVarTypes.empty())
2325 return;
2326
2327 structTy = llvm::StructType::get(moduleTranslation.getLLVMContext(),
2328 privateVarTypes);
2329
2330 llvm::DataLayout dataLayout =
2331 builder.GetInsertBlock()->getModule()->getDataLayout();
2332 llvm::Type *intPtrTy = builder.getIntPtrTy(dataLayout);
2333 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(structTy);
2334
2335 // Heap allocate the structure
2336 structPtr = builder.CreateMalloc(intPtrTy, structTy, allocSize,
2337 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2338 "omp.task.context_ptr");
2339}
2340
2341SmallVector<llvm::Value *> TaskContextStructManager::createGEPsToPrivateVars(
2342 llvm::Value *altStructPtr) const {
2343 SmallVector<llvm::Value *> ret;
2344
2345 // Create GEPs for each struct member
2346 ret.reserve(privateDecls.size());
2347 llvm::Value *zero = builder.getInt32(0);
2348 unsigned i = 0;
2349 for (auto privDecl : privateDecls) {
2350 if (!privDecl.readsFromMold()) {
2351 // Handle this inside of the task so we don't pass unnessecary vars in
2352 ret.push_back(nullptr);
2353 continue;
2354 }
2355 llvm::Value *iVal = builder.getInt32(i);
2356 llvm::Value *gep = builder.CreateGEP(structTy, altStructPtr, {zero, iVal});
2357 ret.push_back(gep);
2358 i += 1;
2359 }
2360 return ret;
2361}
2362
2363void TaskContextStructManager::createGEPsToPrivateVars() {
2364 if (!structPtr)
2365 assert(privateVarTypes.empty());
2366 // Still need to run createGEPsToPrivateVars to populate llvmPrivateVarGEPs
2367 // with null values for skipped private decls
2368
2369 llvmPrivateVarGEPs = createGEPsToPrivateVars(structPtr);
2370}
2371
2372void TaskContextStructManager::freeStructPtr() {
2373 if (!structPtr)
2374 return;
2375
2376 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2377 // Ensure we don't put the call to free() after the terminator
2378 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2379 builder.CreateFree(structPtr);
2380}
2381
2382static void storeAffinityEntry(llvm::IRBuilderBase &builder,
2383 llvm::OpenMPIRBuilder &ompBuilder,
2384 llvm::Value *affinityList, llvm::Value *index,
2385 llvm::Value *addr, llvm::Value *len) {
2386 llvm::StructType *kmpTaskAffinityInfoTy =
2387 ompBuilder.getKmpTaskAffinityInfoTy();
2388 llvm::Value *entry = builder.CreateInBoundsGEP(
2389 kmpTaskAffinityInfoTy, affinityList, index, "omp.affinity.entry");
2390
2391 addr = builder.CreatePtrToInt(addr, kmpTaskAffinityInfoTy->getElementType(0));
2392 len = builder.CreateIntCast(len, kmpTaskAffinityInfoTy->getElementType(1),
2393 /*isSigned=*/false);
2394 llvm::Value *flags = builder.getInt32(0);
2395
2396 builder.CreateStore(addr,
2397 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 0));
2398 builder.CreateStore(len,
2399 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 1));
2400 builder.CreateStore(flags,
2401 builder.CreateStructGEP(kmpTaskAffinityInfoTy, entry, 2));
2402}
2403
2405 llvm::IRBuilderBase &builder,
2406 LLVM::ModuleTranslation &moduleTranslation,
2407 llvm::Value *affinityList) {
2408 for (auto [i, affinityVar] : llvm::enumerate(affinityVars)) {
2409 auto entryOp = affinityVar.getDefiningOp<mlir::omp::AffinityEntryOp>();
2410 assert(entryOp && "affinity item must be omp.affinity_entry");
2411
2412 llvm::Value *addr = moduleTranslation.lookupValue(entryOp.getAddr());
2413 llvm::Value *len = moduleTranslation.lookupValue(entryOp.getLen());
2414 assert(addr && len && "expect affinity addr and len to be non-null");
2415 storeAffinityEntry(builder, *moduleTranslation.getOpenMPBuilder(),
2416 affinityList, builder.getInt64(i), addr, len);
2417 }
2418}
2419
2420static mlir::LogicalResult
2421convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo,
2422 mlir::Block &iteratorRegionBlock,
2423 llvm::IRBuilderBase &builder,
2424 LLVM::ModuleTranslation &moduleTranslation) {
2425 llvm::Value *tmp = linearIV;
2426 for (int d = (int)iterInfo.getDims() - 1; d >= 0; --d) {
2427 llvm::Value *trip = iterInfo.getTrips()[d];
2428 // idx_d = tmp % trip_d
2429 llvm::Value *idx = builder.CreateURem(tmp, trip);
2430 // tmp = tmp / trip_d
2431 tmp = builder.CreateUDiv(tmp, trip);
2432
2433 // physIV_d = lb_d + idx_d * step_d
2434 llvm::Value *physIV = builder.CreateAdd(
2435 iterInfo.getLowerBounds()[d],
2436 builder.CreateMul(idx, iterInfo.getSteps()[d]), "omp.it.phys_iv");
2437
2438 moduleTranslation.mapValue(iteratorRegionBlock.getArgument(d), physIV);
2439 }
2440
2441 // Translate the iterator region into the loop body.
2442 moduleTranslation.mapBlock(&iteratorRegionBlock, builder.GetInsertBlock());
2443 if (mlir::failed(moduleTranslation.convertBlock(iteratorRegionBlock,
2444 /*ignoreArguments=*/true,
2445 builder))) {
2446 return mlir::failure();
2447 }
2448 return mlir::success();
2449}
2450
2452 llvm::function_ref<void(llvm::Value *linearIV, mlir::omp::YieldOp yield)>;
2453
2454static mlir::LogicalResult
2455fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder,
2456 mlir::LLVM::ModuleTranslation &moduleTranslation,
2457 IteratorInfo &iterInfo, llvm::StringRef loopName,
2458 IteratorStoreEntryTy genStoreEntry) {
2459 mlir::Region &itersRegion = itersOp.getRegion();
2460 mlir::Block &iteratorRegionBlock = itersRegion.front();
2461
2462 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
2463
2464 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy bodyIP,
2465 llvm::Value *linearIV) -> llvm::Error {
2466 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2467 builder.restoreIP(bodyIP);
2468
2469 if (failed(convertIteratorRegion(linearIV, iterInfo, iteratorRegionBlock,
2470 builder, moduleTranslation))) {
2471 return llvm::make_error<llvm::StringError>(
2472 "failed to convert iterator region", llvm::inconvertibleErrorCode());
2473 }
2474
2475 auto yield =
2476 mlir::dyn_cast<mlir::omp::YieldOp>(iteratorRegionBlock.getTerminator());
2477 assert(yield && yield.getResults().size() == 1 &&
2478 "expect omp.yield in iterator region to have one result");
2479
2480 genStoreEntry(linearIV, yield);
2481
2482 // Iterator-region block/value mappings are temporary for this conversion,
2483 // clear them to avoid stale entries in ModuleTranslation.
2484 moduleTranslation.forgetMapping(itersRegion);
2485
2486 return llvm::Error::success();
2487 };
2488
2489 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2490 moduleTranslation.getOpenMPBuilder()->createIteratorLoop(
2491 loc, iterInfo.getTotalTrips(), bodyGen, loopName);
2492 if (failed(handleError(afterIP, *itersOp)))
2493 return failure();
2494
2495 builder.restoreIP(*afterIP);
2496
2497 return mlir::success();
2498}
2499
2500static mlir::LogicalResult
2501buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder,
2502 mlir::LLVM::ModuleTranslation &moduleTranslation,
2503 llvm::OpenMPIRBuilder::AffinityData &ad) {
2504
2505 if (taskOp.getAffinityVars().empty() && taskOp.getIterated().empty()) {
2506 ad.Count = nullptr;
2507 ad.Info = nullptr;
2508 return mlir::success();
2509 }
2510
2512 llvm::StructType *kmpTaskAffinityInfoTy =
2513 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2514
2515 auto allocateAffinityList = [&](llvm::Value *count) -> llvm::Value * {
2516 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2517 if (llvm::isa<llvm::Constant>(count) || llvm::isa<llvm::Argument>(count))
2518 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
2519 return builder.CreateAlloca(kmpTaskAffinityInfoTy, count,
2520 "omp.affinity_list");
2521 };
2522
2523 auto createAffinity =
2524 [&](llvm::Value *count,
2525 llvm::Value *info) -> llvm::OpenMPIRBuilder::AffinityData {
2526 llvm::OpenMPIRBuilder::AffinityData ad{};
2527 ad.Count = builder.CreateTrunc(count, builder.getInt32Ty());
2528 ad.Info =
2529 builder.CreatePointerBitCastOrAddrSpaceCast(info, builder.getPtrTy(0));
2530 return ad;
2531 };
2532
2533 if (!taskOp.getAffinityVars().empty()) {
2534 llvm::Value *count = llvm::ConstantInt::get(
2535 builder.getInt64Ty(), taskOp.getAffinityVars().size());
2536 llvm::Value *list = allocateAffinityList(count);
2537 fillAffinityLocators(taskOp.getAffinityVars(), builder, moduleTranslation,
2538 list);
2539 ads.emplace_back(createAffinity(count, list));
2540 }
2541
2542 if (!taskOp.getIterated().empty()) {
2543 for (auto [i, iter] : llvm::enumerate(taskOp.getIterated())) {
2544 auto itersOp = iter.getDefiningOp<omp::IteratorOp>();
2545 assert(itersOp && "iterated value must be defined by omp.iterator");
2546 IteratorInfo iterInfo(itersOp, moduleTranslation, builder);
2547 llvm::Value *affList = allocateAffinityList(iterInfo.getTotalTrips());
2548 if (failed(fillIteratorLoop(
2549 itersOp, builder, moduleTranslation, iterInfo, "iterator",
2550 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2551 auto entryOp = yield.getResults()[0]
2552 .getDefiningOp<mlir::omp::AffinityEntryOp>();
2553 assert(entryOp && "expect yield produce an affinity entry");
2554 llvm::Value *addr =
2555 moduleTranslation.lookupValue(entryOp.getAddr());
2556 llvm::Value *len =
2557 moduleTranslation.lookupValue(entryOp.getLen());
2558 storeAffinityEntry(builder,
2559 *moduleTranslation.getOpenMPBuilder(),
2560 affList, linearIV, addr, len);
2561 })))
2562 return llvm::failure();
2563 ads.emplace_back(createAffinity(iterInfo.getTotalTrips(), affList));
2564 }
2565 }
2566
2567 llvm::Value *totalAffinityCount = builder.getInt32(0);
2568 for (const auto &affinity : ads)
2569 totalAffinityCount = builder.CreateAdd(
2570 totalAffinityCount,
2571 builder.CreateIntCast(affinity.Count, builder.getInt32Ty(),
2572 /*isSigned=*/false));
2573
2574 llvm::Value *affinityInfo = ads.front().Info;
2575 if (ads.size() > 1) {
2576 llvm::StructType *kmpTaskAffinityInfoTy =
2577 moduleTranslation.getOpenMPBuilder()->getKmpTaskAffinityInfoTy();
2578 llvm::Value *affinityInfoElemSize = builder.getInt64(
2579 moduleTranslation.getLLVMModule()->getDataLayout().getTypeAllocSize(
2580 kmpTaskAffinityInfoTy));
2581
2582 llvm::Value *packedAffinityInfo = allocateAffinityList(totalAffinityCount);
2583 llvm::Value *packedAffinityInfoOffset = builder.getInt32(0);
2584 for (const auto &affinity : ads) {
2585 llvm::Value *affinityCount = builder.CreateIntCast(
2586 affinity.Count, builder.getInt32Ty(), /*isSigned=*/false);
2587 llvm::Value *affinityCountInt64 = builder.CreateIntCast(
2588 affinityCount, builder.getInt64Ty(), /*isSigned=*/false);
2589 llvm::Value *affinityInfoSize =
2590 builder.CreateMul(affinityCountInt64, affinityInfoElemSize);
2591
2592 llvm::Value *packedAffinityInfoIndex = builder.CreateIntCast(
2593 packedAffinityInfoOffset, kmpTaskAffinityInfoTy->getElementType(0),
2594 /*isSigned=*/false);
2595 packedAffinityInfoIndex = builder.CreateInBoundsGEP(
2596 kmpTaskAffinityInfoTy, packedAffinityInfo, packedAffinityInfoIndex);
2597
2598 builder.CreateMemCpy(
2599 packedAffinityInfoIndex, llvm::Align(1),
2600 builder.CreatePointerBitCastOrAddrSpaceCast(
2601 affinity.Info, builder.getPtrTy(packedAffinityInfoIndex->getType()
2602 ->getPointerAddressSpace())),
2603 llvm::Align(1), affinityInfoSize);
2604
2605 packedAffinityInfoOffset =
2606 builder.CreateAdd(packedAffinityInfoOffset, affinityCount);
2607 }
2608
2609 affinityInfo = packedAffinityInfo;
2610 }
2611
2612 ad.Count = totalAffinityCount;
2613 ad.Info = affinityInfo;
2614
2615 return mlir::success();
2616}
2617
2618// Allocates a single kmp_dep_info array sized to hold both locator
2619// (non-iterated) and iterated entries, fills the locator entries first, then
2620// runs an iterator loop for each iterator modifier object.
2621static mlir::LogicalResult
2622buildDependData(OperandRange dependVars, std::optional<ArrayAttr> dependKinds,
2623 OperandRange dependIterated,
2624 std::optional<ArrayAttr> dependIteratedKinds,
2625 llvm::IRBuilderBase &builder,
2626 mlir::LLVM::ModuleTranslation &moduleTranslation,
2627 llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) {
2628 if (dependIterated.empty()) {
2629 buildDependDataLocator(dependKinds, dependVars, moduleTranslation,
2630 taskDeps.Deps);
2631 return mlir::success();
2632 }
2633
2634 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2635 llvm::Type *dependInfoTy = ompBuilder.DependInfo;
2636 unsigned numLocator = dependVars.size();
2637
2638 // Compute total count: locator deps + sum of iterator trip counts.
2639 llvm::Value *totalCount =
2640 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2641
2643 for (auto iter : dependIterated) {
2644 auto itersOp = iter.getDefiningOp<mlir::omp::IteratorOp>();
2645 assert(itersOp && "depend_iterated value must be defined by omp.iterator");
2646 iterInfos.emplace_back(itersOp, moduleTranslation, builder);
2647 totalCount =
2648 builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips());
2649 }
2650
2651 // Heap-allocate the kmp_depend_info array so we don't risk
2652 // dynamic-sized alloca outside the entry block (e.g. inside loops).
2653 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy);
2654 llvm::Value *depArray =
2655 builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize,
2656 totalCount, /*MallocF=*/nullptr, ".dep.arr.addr");
2657
2658 // Fill non-iterated entries at indices [0, numLocator).
2659 if (numLocator > 0) {
2661 buildDependDataLocator(dependKinds, dependVars, moduleTranslation, dds);
2662 for (auto [i, dd] : llvm::enumerate(dds)) {
2663 llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i);
2664 llvm::Value *entry =
2665 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2666 ompBuilder.emitTaskDependency(builder, entry, dd);
2667 }
2668 }
2669
2670 // Fill iterated entries starting at index numLocator.
2671 llvm::Value *offset =
2672 llvm::ConstantInt::get(builder.getInt64Ty(), numLocator);
2673 for (auto [i, iterInfo] : llvm::enumerate(iterInfos)) {
2674 auto kindAttr = cast<mlir::omp::ClauseTaskDependAttr>(
2675 dependIteratedKinds->getValue()[i]);
2676 llvm::omp::RTLDependenceKindTy rtlKind =
2677 convertDependKind(kindAttr.getValue());
2678
2679 auto itersOp = dependIterated[i].getDefiningOp<mlir::omp::IteratorOp>();
2680 if (failed(fillIteratorLoop(
2681 itersOp, builder, moduleTranslation, iterInfo, "dep_iterator",
2682 [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) {
2683 llvm::Value *addr =
2684 moduleTranslation.lookupValue(yield.getResults()[0]);
2685 llvm::Value *idx = builder.CreateAdd(offset, linearIV);
2686 llvm::Value *entry =
2687 builder.CreateInBoundsGEP(dependInfoTy, depArray, idx);
2688 ompBuilder.emitTaskDependency(
2689 builder, entry,
2690 llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(),
2691 addr});
2692 })))
2693 return mlir::failure();
2694
2695 // Advance offset by the trip count of this iterator.
2696 offset = builder.CreateAdd(offset, iterInfo.getTotalTrips());
2697 }
2698
2699 taskDeps.DepArray = depArray;
2700 taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty());
2701 return mlir::success();
2702}
2703
2704/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2705static LogicalResult
2706convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2707 LLVM::ModuleTranslation &moduleTranslation) {
2708 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2709 if (failed(checkImplementationStatus(*taskOp)))
2710 return failure();
2711
2712 PrivateVarsInfo privateVarsInfo(taskOp);
2713 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2714 privateVarsInfo.privatizers};
2715
2716 // Allocate and copy private variables before creating the task. This avoids
2717 // accessing invalid memory if (after this scope ends) the private variables
2718 // are initialized from host variables or if the variables are copied into
2719 // from host variables (firstprivate). The insertion point is just before
2720 // where the code for creating and scheduling the task will go. That puts this
2721 // code outside of the outlined task region, which is what we want because
2722 // this way the initialization and copy regions are executed immediately while
2723 // the host variable data are still live.
2724
2725 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2726 findAllocaInsertPoint(builder, moduleTranslation);
2727
2728 // Not using splitBB() because that requires the current block to have a
2729 // terminator.
2730 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2731 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2732 builder.getContext(), "omp.task.start",
2733 /*Parent=*/builder.GetInsertBlock()->getParent());
2734 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(taskStartBlock);
2735 builder.SetInsertPoint(branchToTaskStartBlock);
2736
2737 // Now do this again to make the initialization and copy blocks
2738 llvm::BasicBlock *copyBlock =
2739 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
2740 llvm::BasicBlock *initBlock =
2741 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
2742
2743 // Now the control flow graph should look like
2744 // starter_block:
2745 // <---- where we started when convertOmpTaskOp was called
2746 // br %omp.private.init
2747 // omp.private.init:
2748 // br %omp.private.copy
2749 // omp.private.copy:
2750 // br %omp.task.start
2751 // omp.task.start:
2752 // <---- where we want the insertion point to be when we call createTask()
2753
2754 // Save the alloca insertion point on ModuleTranslation stack for use in
2755 // nested regions.
2757 moduleTranslation, allocaIP);
2758
2759 // Allocate and initialize private variables
2760 builder.SetInsertPoint(initBlock->getTerminator());
2761
2762 // Create task variable structure
2763 taskStructMgr.generateTaskContextStruct();
2764 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2765 // of the body otherwise it will be the GEP not the struct which is fowarded
2766 // to the outlined function. GEPs forwarded in this way are passed in a
2767 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2768 // which may not be executed until after the current stack frame goes out of
2769 // scope.
2770 taskStructMgr.createGEPsToPrivateVars();
2771
2772 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2773 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2774 privateVarsInfo.blockArgs,
2775 taskStructMgr.getLLVMPrivateVarGEPs())) {
2776 // To be handled inside the task.
2777 if (!privDecl.readsFromMold())
2778 continue;
2779 assert(llvmPrivateVarAlloc &&
2780 "reads from mold so shouldn't have been skipped");
2781
2782 llvm::Expected<llvm::Value *> privateVarOrErr =
2783 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2784 blockArg, llvmPrivateVarAlloc, initBlock);
2785 if (!privateVarOrErr)
2786 return handleError(privateVarOrErr, *taskOp.getOperation());
2787
2789
2790 // TODO: this is a bit of a hack for Fortran character boxes.
2791 // Character boxes are passed by value into the init region and then the
2792 // initialized character box is yielded by value. Here we need to store the
2793 // yielded value into the private allocation, and load the private
2794 // allocation to match the type expected by region block arguments.
2795 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2796 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2797 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2798 // Load it so we have the value pointed to by the GEP
2799 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2800 llvmPrivateVarAlloc);
2801 }
2802 assert(llvmPrivateVarAlloc->getType() ==
2803 moduleTranslation.convertType(blockArg.getType()));
2804
2805 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2806 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2807 // stack allocated structure.
2808 }
2809
2810 // firstprivate copy region
2811 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
2812 if (failed(copyFirstPrivateVars(
2813 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2814 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2815 taskOp.getPrivateNeedsBarrier())))
2816 return llvm::failure();
2817
2818 llvm::OpenMPIRBuilder::AffinityData ad;
2819 if (failed(buildAffinityData(taskOp, builder, moduleTranslation, ad)))
2820 return llvm::failure();
2821
2822 // Set up for call to createTask()
2823 builder.SetInsertPoint(taskStartBlock);
2824
2825 auto bodyCB = [&](InsertPointTy allocaIP,
2826 InsertPointTy codegenIP) -> llvm::Error {
2827 // Save the alloca insertion point on ModuleTranslation stack for use in
2828 // nested regions.
2830 moduleTranslation, allocaIP);
2831
2832 // translate the body of the task:
2833 builder.restoreIP(codegenIP);
2834
2835 llvm::BasicBlock *privInitBlock = nullptr;
2836 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2837 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2838 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2839 privateVarsInfo.mlirVars))) {
2840 auto [blockArg, privDecl, mlirPrivVar] = zip;
2841 // This is handled before the task executes
2842 if (privDecl.readsFromMold())
2843 continue;
2844
2845 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2846 llvm::Type *llvmAllocType =
2847 moduleTranslation.convertType(privDecl.getType());
2848 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2849 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2850 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2851
2852 llvm::Expected<llvm::Value *> privateVarOrError =
2853 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2854 blockArg, llvmPrivateVar, privInitBlock);
2855 if (!privateVarOrError)
2856 return privateVarOrError.takeError();
2857 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2858 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2859 }
2860
2861 taskStructMgr.createGEPsToPrivateVars();
2862 for (auto [i, llvmPrivVar] :
2863 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2864 if (!llvmPrivVar) {
2865 assert(privateVarsInfo.llvmVars[i] &&
2866 "This is added in the loop above");
2867 continue;
2868 }
2869 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2870 }
2871
2872 // Find and map the addresses of each variable within the task context
2873 // structure
2874 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2875 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2876 privateVarsInfo.privatizers)) {
2877 // This was handled above.
2878 if (!privateDecl.readsFromMold())
2879 continue;
2880 // Fix broken pass-by-value case for Fortran character boxes
2881 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2882 llvmPrivateVar = builder.CreateLoad(
2883 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2884 }
2885 assert(llvmPrivateVar->getType() ==
2886 moduleTranslation.convertType(blockArg.getType()));
2887 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2888 }
2889
2890 auto continuationBlockOrError = convertOmpOpRegions(
2891 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2892 if (failed(handleError(continuationBlockOrError, *taskOp)))
2893 return llvm::make_error<PreviouslyReportedError>();
2894
2895 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2896
2897 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2898 privateVarsInfo.llvmVars,
2899 privateVarsInfo.privatizers)))
2900 return llvm::make_error<PreviouslyReportedError>();
2901
2902 // Free heap allocated task context structure at the end of the task.
2903 taskStructMgr.freeStructPtr();
2904
2905 return llvm::Error::success();
2906 };
2907
2908 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2909 SmallVector<llvm::UncondBrInst *> cancelTerminators;
2910 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2911 // which is canceled. This is handled here because it is the task's cleanup
2912 // block which should be branched to.
2913 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2914 llvm::omp::Directive::OMPD_taskgroup);
2915
2916 llvm::OpenMPIRBuilder::DependenciesInfo dependencies;
2917 if (failed(buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(),
2918 taskOp.getDependIterated(),
2919 taskOp.getDependIteratedKinds(), builder,
2920 moduleTranslation, dependencies)))
2921 return failure();
2922
2923 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2924 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2925 moduleTranslation.getOpenMPBuilder()->createTask(
2926 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
2927 moduleTranslation.lookupValue(taskOp.getFinal()),
2928 moduleTranslation.lookupValue(taskOp.getIfExpr()), dependencies, ad,
2929 taskOp.getMergeable(),
2930 moduleTranslation.lookupValue(taskOp.getEventHandle()),
2931 moduleTranslation.lookupValue(taskOp.getPriority()));
2932
2933 if (failed(handleError(afterIP, *taskOp)))
2934 return failure();
2935
2936 // Set the correct branch target for task cancellation
2937 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2938
2939 builder.restoreIP(*afterIP);
2940
2941 if (dependencies.DepArray)
2942 builder.CreateFree(dependencies.DepArray);
2943
2944 return success();
2945}
2946
2947/// The correct entry point is convertOmpTaskloopContextOp. This gets called
2948/// whilst lowering the body of the taskloop context (i.e. the task function).
2949static LogicalResult
2950convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp,
2951 llvm::IRBuilderBase &builder,
2952 LLVM::ModuleTranslation &moduleTranslation) {
2953 mlir::Operation &opInst = *loopWrapperOp.getOperation();
2954 if (failed(checkImplementationStatus(opInst)))
2955 return failure();
2956
2957 // Recurse into the loop body.
2958 auto continuationBlockOrError = convertOmpOpRegions(
2959 loopWrapperOp.getRegion(), "omp.taskloop.wrapper.region", builder,
2960 moduleTranslation);
2961
2962 if (failed(handleError(continuationBlockOrError, opInst)))
2963 return failure();
2964
2965 builder.SetInsertPoint(continuationBlockOrError.get());
2966 return success();
2967}
2968
2969/// Look up the given value in the mapping, and if it's not there, translate its
2970/// defining operation at the current builder insertion point. Only pure,
2971/// regionless operations are supported because the same operation will later be
2972/// translated again when the taskloop body itself is lowered.
2973static llvm::Expected<llvm::Value *>
2975 LLVM::ModuleTranslation &moduleTranslation,
2976 llvm::IRBuilderBase &builder) {
2977 if (llvm::Value *mapped = moduleTranslation.lookupValue(value))
2978 return mapped;
2979
2980 Operation *defOp = value.getDefiningOp();
2981 if (!defOp)
2982 return llvm::make_error<llvm::StringError>(
2983 "value is a block argument and is not mapped",
2984 llvm::inconvertibleErrorCode());
2985 if (defOp->getNumRegions() != 0 || !isPure(defOp))
2986 return llvm::make_error<llvm::StringError>(
2987 "unsupported op defining taskloop loop bound",
2988 llvm::inconvertibleErrorCode());
2989
2990 SmallVector<Value> mappingsToRemove;
2991 mappingsToRemove.reserve(defOp->getNumOperands() + defOp->getNumResults());
2992 for (Value operand : defOp->getOperands()) {
2993 if (moduleTranslation.lookupValue(operand))
2994 continue;
2995
2996 llvm::Expected<llvm::Value *> operandOrError =
2997 lookupOrTranslatePureValue(operand, moduleTranslation, builder);
2998 if (!operandOrError)
2999 return operandOrError.takeError();
3000 moduleTranslation.mapValue(operand, *operandOrError);
3001 mappingsToRemove.push_back(operand);
3002 }
3003
3004 if (failed(moduleTranslation.convertOperation(*defOp, builder)))
3005 return llvm::make_error<llvm::StringError>(
3006 "failed to convert op defining taskloop loop bound",
3007 llvm::inconvertibleErrorCode());
3008
3009 llvm::Value *result = moduleTranslation.lookupValue(value);
3010 assert(result && "expected conversion of loop bound op to produce a value");
3011
3012 for (Value resultValue : defOp->getResults()) {
3013 if (moduleTranslation.lookupValue(resultValue))
3014 mappingsToRemove.push_back(resultValue);
3015 }
3016 for (Value mappedValue : mappingsToRemove)
3017 moduleTranslation.forgetMapping(mappedValue);
3018
3019 return result;
3020}
3021
3022static llvm::Error
3023computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
3024 LLVM::ModuleTranslation &moduleTranslation,
3025 llvm::Value *&lbVal, llvm::Value *&ubVal,
3026 llvm::Value *&stepVal) {
3027 Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
3028 Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
3029 Operation::operand_range steps = loopOp.getLoopSteps();
3030
3031 llvm::Expected<llvm::Value *> firstLbOrErr =
3032 lookupOrTranslatePureValue(lowerBounds[0], moduleTranslation, builder);
3033 if (!firstLbOrErr)
3034 return firstLbOrErr.takeError();
3035
3036 llvm::Type *boundType = (*firstLbOrErr)->getType();
3037 ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3038 if (loopOp.getCollapseNumLoops() > 1) {
3039 // In cases where Collapse is used with Taskloop, the upper bound of the
3040 // iteration space needs to be recalculated to cater for the collapsed loop.
3041 // The Collapsed Loop UpperBound is the product of all collapsed
3042 // loop's tripcount.
3043 // The LowerBound for collapsed loops is always 1. When the loops are
3044 // collapsed, it will reset the bounds and introduce processing to ensure
3045 // the index's are presented as expected. As this happens after creating
3046 // Taskloop, these bounds need predicting. Example:
3047 // !$omp taskloop collapse(2)
3048 // do i = 1, 10
3049 // do j = 1, 5
3050 // ..
3051 // end do
3052 // end do
3053 // This loop above has a total of 50 iterations, so the lb will be 1, and
3054 // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
3055 // that i and j are properly presented when used in the loop.
3056 for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
3058 i == 0 ? std::move(firstLbOrErr)
3059 : lookupOrTranslatePureValue(lowerBounds[i], moduleTranslation,
3060 builder);
3061 if (!lbOrErr)
3062 return lbOrErr.takeError();
3064 upperBounds[i], moduleTranslation, builder);
3065 if (!ubOrErr)
3066 return ubOrErr.takeError();
3068 lookupOrTranslatePureValue(steps[i], moduleTranslation, builder);
3069 if (!stepOrErr)
3070 return stepOrErr.takeError();
3071
3072 llvm::Value *loopLb = *lbOrErr;
3073 llvm::Value *loopUb = *ubOrErr;
3074 llvm::Value *loopStep = *stepOrErr;
3075 // In some cases, such as where the ub is less than the lb so the loop
3076 // steps down, the calculation for the loopTripCount is swapped. To ensure
3077 // the correct value is found, calculate both UB - LB and LB - UB then
3078 // select which value to use depending on how the loop has been
3079 // configured.
3080 llvm::Value *loopLbMinusOne = builder.CreateSub(
3081 loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3082 llvm::Value *loopUbMinusOne = builder.CreateSub(
3083 loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
3084 llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
3085 llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
3086 llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
3087 llvm::Value *loopTripCount =
3088 builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
3089 loopTripCount = builder.CreateBinaryIntrinsic(
3090 llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
3091 // For loops that have a step value not equal to 1, we need to adjust the
3092 // trip count to ensure the correct number of iterations for the loop is
3093 // captured.
3094 llvm::Value *loopTripCountDivStep =
3095 builder.CreateSDiv(loopTripCount, loopStep);
3096 loopTripCountDivStep = builder.CreateBinaryIntrinsic(
3097 llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
3098 llvm::Value *loopTripCountRem =
3099 builder.CreateSRem(loopTripCount, loopStep);
3100 loopTripCountRem = builder.CreateBinaryIntrinsic(
3101 llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
3102 llvm::Value *needsRoundUp = builder.CreateICmpNE(
3103 loopTripCountRem,
3104 builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
3105 0));
3106 loopTripCount =
3107 builder.CreateAdd(loopTripCountDivStep,
3108 builder.CreateZExtOrTrunc(
3109 needsRoundUp, loopTripCountDivStep->getType()));
3110 ubVal = builder.CreateMul(ubVal, loopTripCount);
3111 }
3112 lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3113 stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
3114 } else {
3116 lookupOrTranslatePureValue(upperBounds[0], moduleTranslation, builder);
3117 if (!ubOrErr)
3118 return ubOrErr.takeError();
3120 lookupOrTranslatePureValue(steps[0], moduleTranslation, builder);
3121 if (!stepOrErr)
3122 return stepOrErr.takeError();
3123 lbVal = *firstLbOrErr;
3124 ubVal = *ubOrErr;
3125 stepVal = *stepOrErr;
3126 }
3127
3128 assert(lbVal != nullptr && "Expected value for lbVal");
3129 assert(ubVal != nullptr && "Expected value for ubVal");
3130 assert(stepVal != nullptr && "Expected value for stepVal");
3131 return llvm::Error::success();
3132}
3133
3134// Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
3135static LogicalResult
3136convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
3137 llvm::IRBuilderBase &builder,
3138 LLVM::ModuleTranslation &moduleTranslation) {
3139 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3140 mlir::Operation &opInst = *contextOp.getOperation();
3141 omp::TaskloopWrapperOp loopWrapperOp = contextOp.getLoopOp();
3142 if (failed(checkImplementationStatus(opInst)))
3143 return failure();
3144
3145 // It stores the pointer of allocated firstprivate copies,
3146 // which can be used later for freeing the allocated space.
3147 SmallVector<llvm::Value *> llvmFirstPrivateVars;
3148 PrivateVarsInfo privateVarsInfo(contextOp);
3149 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
3150 privateVarsInfo.privatizers};
3151
3152 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3153 findAllocaInsertPoint(builder, moduleTranslation);
3154
3155 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
3156 llvm::BasicBlock *taskloopStartBlock = llvm::BasicBlock::Create(
3157 builder.getContext(), "omp.taskloop.wrapper.start",
3158 /*Parent=*/builder.GetInsertBlock()->getParent());
3159 llvm::Instruction *branchToTaskloopStartBlock =
3160 builder.CreateBr(taskloopStartBlock);
3161 builder.SetInsertPoint(branchToTaskloopStartBlock);
3162
3163 llvm::BasicBlock *copyBlock =
3164 splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
3165 llvm::BasicBlock *initBlock =
3166 splitBB(builder, /*CreateBranch=*/true, "omp.private.init");
3167
3169 moduleTranslation, allocaIP);
3170
3171 // Allocate and initialize private variables
3172 builder.SetInsertPoint(initBlock->getTerminator());
3173
3174 // TODO: don't allocate if the loop has zero iterations.
3175 taskStructMgr.generateTaskContextStruct();
3176 taskStructMgr.createGEPsToPrivateVars();
3177
3178 llvmFirstPrivateVars.resize(privateVarsInfo.blockArgs.size());
3179
3180 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3181 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
3182 privateVarsInfo.blockArgs, taskStructMgr.getLLVMPrivateVarGEPs()))) {
3183 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] = zip;
3184 // To be handled inside the taskloop.
3185 if (!privDecl.readsFromMold())
3186 continue;
3187 assert(llvmPrivateVarAlloc &&
3188 "reads from mold so shouldn't have been skipped");
3189
3190 llvm::Expected<llvm::Value *> privateVarOrErr =
3191 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3192 blockArg, llvmPrivateVarAlloc, initBlock);
3193 if (!privateVarOrErr)
3194 return handleError(privateVarOrErr, *contextOp.getOperation());
3195
3196 llvmFirstPrivateVars[i] = privateVarOrErr.get();
3197
3198 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3199 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
3200
3201 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3202 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3203 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3204 // Load it so we have the value pointed to by the GEP
3205 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
3206 llvmPrivateVarAlloc);
3207 }
3208 assert(llvmPrivateVarAlloc->getType() ==
3209 moduleTranslation.convertType(blockArg.getType()));
3210 }
3211
3212 // firstprivate copy region
3213 setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
3214 if (failed(copyFirstPrivateVars(
3215 contextOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3216 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
3217 contextOp.getPrivateNeedsBarrier())))
3218 return llvm::failure();
3219
3220 // Set up inserttion point for call to createTaskloop()
3221 builder.SetInsertPoint(taskloopStartBlock);
3222
3223 auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
3224 llvm::Value *lbVal = nullptr;
3225 llvm::Value *ubVal = nullptr;
3226 llvm::Value *stepVal = nullptr;
3227 if (llvm::Error err = computeTaskloopBounds(
3228 loopOp, builder, moduleTranslation, lbVal, ubVal, stepVal))
3229 return handleError(std::move(err), opInst);
3230
3231 auto bodyCB = [&](InsertPointTy allocaIP,
3232 InsertPointTy codegenIP) -> llvm::Error {
3233 // Save the alloca insertion point on ModuleTranslation stack for use in
3234 // nested regions.
3236 moduleTranslation, allocaIP);
3237
3238 // translate the body of the taskloop:
3239 builder.restoreIP(codegenIP);
3240
3241 llvm::BasicBlock *privInitBlock = nullptr;
3242 privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
3243 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
3244 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
3245 privateVarsInfo.mlirVars))) {
3246 auto [blockArg, privDecl, mlirPrivVar] = zip;
3247 // This is handled before the task executes
3248 if (privDecl.readsFromMold())
3249 continue;
3250
3251 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3252 llvm::Type *llvmAllocType =
3253 moduleTranslation.convertType(privDecl.getType());
3254 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
3255 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
3256 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
3257
3258 llvm::Expected<llvm::Value *> privateVarOrError =
3259 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
3260 blockArg, llvmPrivateVar, privInitBlock);
3261 if (!privateVarOrError)
3262 return privateVarOrError.takeError();
3263 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
3264 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
3265 }
3266
3267 taskStructMgr.createGEPsToPrivateVars();
3268 for (auto [i, llvmPrivVar] :
3269 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
3270 if (!llvmPrivVar) {
3271 assert(privateVarsInfo.llvmVars[i] &&
3272 "This is added in the loop above");
3273 continue;
3274 }
3275 privateVarsInfo.llvmVars[i] = llvmPrivVar;
3276 }
3277
3278 // Find and map the addresses of each variable within the taskloop context
3279 // structure
3280 for (auto [blockArg, llvmPrivateVar, privateDecl] :
3281 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
3282 privateVarsInfo.privatizers)) {
3283 // This was handled above.
3284 if (!privateDecl.readsFromMold())
3285 continue;
3286 // Fix broken pass-by-value case for Fortran character boxes
3287 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3288 llvmPrivateVar = builder.CreateLoad(
3289 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
3290 }
3291 assert(llvmPrivateVar->getType() ==
3292 moduleTranslation.convertType(blockArg.getType()));
3293 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
3294 }
3295
3296 // Lower the contents of the taskloop context region: this is the body of
3297 // the generated task, not the loop.
3298 auto continuationBlockOrError = convertOmpOpRegions(
3299 contextOp.getRegion(), "omp.taskloop.context.region", builder,
3300 moduleTranslation);
3301
3302 if (failed(handleError(continuationBlockOrError, opInst)))
3303 return llvm::make_error<PreviouslyReportedError>();
3304
3305 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
3306
3307 // This is freeing the private variables as mapped inside of the task: these
3308 // will be per-task private copies possibly after task duplication. This is
3309 // handled transparently by how these are passed to the structure passed
3310 // into the outlined function. When the task is duplicated, that structure
3311 // is duplicated too.
3312 if (failed(cleanupPrivateVars(builder, moduleTranslation,
3313 contextOp.getLoc(), privateVarsInfo.llvmVars,
3314 privateVarsInfo.privatizers)))
3315 return llvm::make_error<PreviouslyReportedError>();
3316 // Similarly, the task context structure freed inside the task is the
3317 // per-task copy after task duplication.
3318 taskStructMgr.freeStructPtr();
3319
3320 return llvm::Error::success();
3321 };
3322
3323 // Taskloop divides into an appropriate number of tasks by repeatedly
3324 // duplicating the original task. Each time this is done, the task context
3325 // structure must be duplicated too.
3326 auto taskDupCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
3327 llvm::Value *destPtr, llvm::Value *srcPtr)
3329 llvm::IRBuilderBase::InsertPointGuard guard(builder);
3330 builder.restoreIP(codegenIP);
3331
3332 llvm::Type *ptrTy =
3333 builder.getPtrTy(srcPtr->getType()->getPointerAddressSpace());
3334 llvm::Value *src =
3335 builder.CreateLoad(ptrTy, srcPtr, "omp.taskloop.context.src");
3336
3337 TaskContextStructManager &srcStructMgr = taskStructMgr;
3338 TaskContextStructManager destStructMgr(builder, moduleTranslation,
3339 privateVarsInfo.privatizers);
3340 destStructMgr.generateTaskContextStruct();
3341 llvm::Value *dest = destStructMgr.getStructPtr();
3342 dest->setName("omp.taskloop.context.dest");
3343 builder.CreateStore(dest, destPtr);
3344
3346 srcStructMgr.createGEPsToPrivateVars(src);
3348 destStructMgr.createGEPsToPrivateVars(dest);
3349
3350 // Inline init regions.
3351 for (auto [privDecl, mold, blockArg, llvmPrivateVarAlloc] :
3352 llvm::zip_equal(privateVarsInfo.privatizers, srcGEPs,
3353 privateVarsInfo.blockArgs, destGEPs)) {
3354 // To be handled inside task body.
3355 if (!privDecl.readsFromMold())
3356 continue;
3357 assert(llvmPrivateVarAlloc &&
3358 "reads from mold so shouldn't have been skipped");
3359
3360 llvm::Expected<llvm::Value *> privateVarOrErr =
3361 initPrivateVar(builder, moduleTranslation, privDecl, mold, blockArg,
3362 llvmPrivateVarAlloc, builder.GetInsertBlock());
3363 if (!privateVarOrErr)
3364 return privateVarOrErr.takeError();
3365
3367
3368 // TODO: this is a bit of a hack for Fortran character boxes.
3369 // Character boxes are passed by value into the init region and then the
3370 // initialized character box is yielded by value. Here we need to store
3371 // the yielded value into the private allocation, and load the private
3372 // allocation to match the type expected by region block arguments.
3373 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
3374 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
3375 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
3376 // Load it so we have the value pointed to by the GEP
3377 llvmPrivateVarAlloc = builder.CreateLoad(
3378 privateVarOrErr.get()->getType(), llvmPrivateVarAlloc);
3379 }
3380 assert(llvmPrivateVarAlloc->getType() ==
3381 moduleTranslation.convertType(blockArg.getType()));
3382
3383 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body
3384 // callback so that OpenMPIRBuilder doesn't try to pass each GEP address
3385 // through a stack allocated structure.
3386 }
3387
3388 if (failed(copyFirstPrivateVars(contextOp.getOperation(), builder,
3389 moduleTranslation, srcGEPs, destGEPs,
3390 privateVarsInfo.privatizers,
3391 contextOp.getPrivateNeedsBarrier())))
3392 return llvm::make_error<PreviouslyReportedError>();
3393
3394 return builder.saveIP();
3395 };
3396
3397 auto loopInfo = [&]() -> llvm::Expected<llvm::CanonicalLoopInfo *> {
3398 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3399 return loopInfo;
3400 };
3401
3402 llvm::Value *ifCond = nullptr;
3403 llvm::Value *grainsize = nullptr;
3404 int sched = 0; // default
3405 mlir::Value grainsizeVal = contextOp.getGrainsize();
3406 mlir::Value numTasksVal = contextOp.getNumTasks();
3407 if (Value ifVar = contextOp.getIfExpr())
3408 ifCond = moduleTranslation.lookupValue(ifVar);
3409 if (grainsizeVal) {
3410 grainsize = moduleTranslation.lookupValue(grainsizeVal);
3411 sched = 1; // grainsize
3412 } else if (numTasksVal) {
3413 grainsize = moduleTranslation.lookupValue(numTasksVal);
3414 sched = 2; // num_tasks
3415 }
3416
3417 llvm::OpenMPIRBuilder::TaskDupCallbackTy taskDupOrNull = nullptr;
3418 if (taskStructMgr.getStructPtr())
3419 taskDupOrNull = taskDupCB;
3420
3421 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
3422 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3423 // The directive to match here is OMPD_taskgroup because it is the
3424 // taskgroup which is canceled. This is handled here because it is the
3425 // task's cleanup block which should be branched to. It doesn't depend upon
3426 // nogroup because even in that case the taskloop might still be inside an
3427 // explicit taskgroup.
3428 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, contextOp,
3429 llvm::omp::Directive::OMPD_taskgroup);
3430
3431 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3432 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3433 moduleTranslation.getOpenMPBuilder()->createTaskloop(
3434 ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
3435 contextOp.getUntied(), ifCond, grainsize, contextOp.getNogroup(),
3436 sched, moduleTranslation.lookupValue(contextOp.getFinal()),
3437 contextOp.getMergeable(),
3438 moduleTranslation.lookupValue(contextOp.getPriority()),
3439 loopOp.getCollapseNumLoops(), taskDupOrNull,
3440 taskStructMgr.getStructPtr());
3441
3442 if (failed(handleError(afterIP, opInst)))
3443 return failure();
3444
3445 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
3446
3447 builder.restoreIP(*afterIP);
3448 return success();
3449}
3450
3451/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
3452static LogicalResult
3453convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
3454 LLVM::ModuleTranslation &moduleTranslation) {
3455 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3456 if (failed(checkImplementationStatus(*tgOp)))
3457 return failure();
3458
3459 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
3460 builder.restoreIP(codegenIP);
3461 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
3462 builder, moduleTranslation)
3463 .takeError();
3464 };
3465
3466 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
3467 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3468 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3469 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
3470 bodyCB);
3471
3472 if (failed(handleError(afterIP, *tgOp)))
3473 return failure();
3474
3475 builder.restoreIP(*afterIP);
3476 return success();
3477}
3478
3479static LogicalResult
3480convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
3481 LLVM::ModuleTranslation &moduleTranslation) {
3482 if (failed(checkImplementationStatus(*twOp)))
3483 return failure();
3484
3485 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
3486 return success();
3487}
3488
3489/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
3490static LogicalResult
3491convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
3492 LLVM::ModuleTranslation &moduleTranslation) {
3493 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3494 auto wsloopOp = cast<omp::WsloopOp>(opInst);
3495 if (failed(checkImplementationStatus(opInst)))
3496 return failure();
3497
3498 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
3499 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
3500 assert(isByRef.size() == wsloopOp.getNumReductionVars());
3501
3502 // Static is the default.
3503 auto schedule =
3504 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
3505
3506 // Find the loop configuration.
3507 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
3508 llvm::Type *ivType = step->getType();
3509 llvm::Value *chunk = nullptr;
3510 if (wsloopOp.getScheduleChunk()) {
3511 llvm::Value *chunkVar =
3512 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
3513 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3514 }
3515
3516 omp::DistributeOp distributeOp = nullptr;
3517 llvm::Value *distScheduleChunk = nullptr;
3518 bool hasDistSchedule = false;
3519 if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
3520 distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
3521 hasDistSchedule = distributeOp.getDistScheduleStatic();
3522 if (distributeOp.getDistScheduleChunkSize()) {
3523 llvm::Value *chunkVar = moduleTranslation.lookupValue(
3524 distributeOp.getDistScheduleChunkSize());
3525 distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
3526 }
3527 }
3528
3529 PrivateVarsInfo privateVarsInfo(wsloopOp);
3530
3532 collectReductionDecls(wsloopOp, reductionDecls);
3533 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3534 findAllocaInsertPoint(builder, moduleTranslation);
3535
3536 SmallVector<llvm::Value *> privateReductionVariables(
3537 wsloopOp.getNumReductionVars());
3538
3540 builder, moduleTranslation, privateVarsInfo, allocaIP);
3541 if (handleError(afterAllocas, opInst).failed())
3542 return failure();
3543
3544 DenseMap<Value, llvm::Value *> reductionVariableMap;
3545
3546 MutableArrayRef<BlockArgument> reductionArgs =
3547 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3548
3549 SmallVector<DeferredStore> deferredStores;
3550
3551 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
3552 moduleTranslation, allocaIP, reductionDecls,
3553 privateReductionVariables, reductionVariableMap,
3554 deferredStores, isByRef)))
3555 return failure();
3556
3557 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3558 opInst)
3559 .failed())
3560 return failure();
3561
3562 if (failed(copyFirstPrivateVars(
3563 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
3564 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3565 wsloopOp.getPrivateNeedsBarrier())))
3566 return failure();
3567
3568 assert(afterAllocas.get()->getSinglePredecessor());
3569 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
3570 moduleTranslation,
3571 afterAllocas.get()->getSinglePredecessor(),
3572 reductionDecls, privateReductionVariables,
3573 reductionVariableMap, isByRef, deferredStores)))
3574 return failure();
3575
3576 // TODO: Handle doacross loops when the ordered clause has a parameter.
3577 bool isOrdered = wsloopOp.getOrdered().has_value();
3578 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
3579 bool isSimd = wsloopOp.getScheduleSimd();
3580 bool loopNeedsBarrier = !wsloopOp.getNowait();
3581
3582 // The only legal way for the direct parent to be omp.distribute is that this
3583 // represents 'distribute parallel do'. Otherwise, this is a regular
3584 // worksharing loop.
3585 llvm::omp::WorksharingLoopType workshareLoopType =
3586 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
3587 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
3588 : llvm::omp::WorksharingLoopType::ForStaticLoop;
3589
3590 SmallVector<llvm::UncondBrInst *> cancelTerminators;
3591 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
3592 llvm::omp::Directive::OMPD_for);
3593
3594 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3595
3596 // Initialize linear variables and linear step
3597 LinearClauseProcessor linearClauseProcessor;
3598
3599 if (!wsloopOp.getLinearVars().empty()) {
3600 auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
3601 for (mlir::Attribute linearVarType : linearVarTypes)
3602 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3603
3604 for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
3605 linearClauseProcessor.createLinearVar(
3606 builder, moduleTranslation, moduleTranslation.lookupValue(linearVar),
3607 idx);
3608 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
3609 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3610 }
3611
3613 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
3614
3615 if (failed(handleError(regionBlock, opInst)))
3616 return failure();
3617
3618 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
3619
3620 // Emit Initialization and Update IR for linear variables
3621 if (!wsloopOp.getLinearVars().empty()) {
3622 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
3623 loopInfo->getPreheader());
3624 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3625 moduleTranslation.getOpenMPBuilder()->createBarrier(
3626 builder.saveIP(), llvm::omp::OMPD_barrier);
3627 if (failed(handleError(afterBarrierIP, *loopOp)))
3628 return failure();
3629 builder.restoreIP(*afterBarrierIP);
3630 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
3631 loopInfo->getIndVar());
3632 linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
3633 }
3634
3635 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
3636
3637 // Check if we can generate no-loop kernel
3638 bool noLoopMode = false;
3639 omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
3640 if (targetOp) {
3641 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
3642 // We need this check because, without it, noLoopMode would be set to true
3643 // for every omp.wsloop nested inside a no-loop SPMD target region, even if
3644 // that loop is not the top-level SPMD one.
3645 if (loopOp == targetCapturedOp) {
3646 omp::TargetRegionFlags kernelFlags =
3647 targetOp.getKernelExecFlags(targetCapturedOp);
3648 if (omp::bitEnumContainsAll(kernelFlags,
3649 omp::TargetRegionFlags::spmd |
3650 omp::TargetRegionFlags::no_loop) &&
3651 !omp::bitEnumContainsAny(kernelFlags,
3652 omp::TargetRegionFlags::generic))
3653 noLoopMode = true;
3654 }
3655 }
3656
3657 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3658 ompBuilder->applyWorkshareLoop(
3659 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
3660 convertToScheduleKind(schedule), chunk, isSimd,
3661 scheduleMod == omp::ScheduleModifier::monotonic,
3662 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3663 workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
3664
3665 if (failed(handleError(wsloopIP, opInst)))
3666 return failure();
3667
3668 // Emit finalization and in-place rewrites for linear vars.
3669 if (!wsloopOp.getLinearVars().empty()) {
3670 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
3671 assert(loopInfo->getLastIter() &&
3672 "`lastiter` in CanonicalLoopInfo is nullptr");
3673 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
3674 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
3675 loopInfo->getLastIter());
3676 if (failed(handleError(afterBarrierIP, *loopOp)))
3677 return failure();
3678 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
3679 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
3680 index);
3681 builder.restoreIP(oldIP);
3682 }
3683
3684 // Set the correct branch target for task cancellation
3685 popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
3686
3687 // Process the reductions if required.
3688 if (failed(createReductionsAndCleanup(
3689 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
3690 privateReductionVariables, isByRef, wsloopOp.getNowait(),
3691 /*isTeamsReduction=*/false)))
3692 return failure();
3693
3694 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
3695 privateVarsInfo.llvmVars,
3696 privateVarsInfo.privatizers);
3697}
3698
3699/// Converts the OpenMP parallel operation to LLVM IR.
3700static LogicalResult
3701convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
3702 LLVM::ModuleTranslation &moduleTranslation) {
3703 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3704 ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
3705 assert(isByRef.size() == opInst.getNumReductionVars());
3706 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3707 bool isCancellable = constructIsCancellable(opInst);
3708
3709 if (failed(checkImplementationStatus(*opInst)))
3710 return failure();
3711
3712 PrivateVarsInfo privateVarsInfo(opInst);
3713
3714 // Collect reduction declarations
3716 collectReductionDecls(opInst, reductionDecls);
3717 SmallVector<llvm::Value *> privateReductionVariables(
3718 opInst.getNumReductionVars());
3719 SmallVector<DeferredStore> deferredStores;
3720
3721 auto bodyGenCB = [&](InsertPointTy allocaIP,
3722 InsertPointTy codeGenIP) -> llvm::Error {
3724 builder, moduleTranslation, privateVarsInfo, allocaIP);
3725 if (handleError(afterAllocas, *opInst).failed())
3726 return llvm::make_error<PreviouslyReportedError>();
3727
3728 // Allocate reduction vars
3729 DenseMap<Value, llvm::Value *> reductionVariableMap;
3730
3731 MutableArrayRef<BlockArgument> reductionArgs =
3732 cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
3733
3734 allocaIP =
3735 InsertPointTy(allocaIP.getBlock(),
3736 allocaIP.getBlock()->getTerminator()->getIterator());
3737
3738 if (failed(allocReductionVars(
3739 opInst, reductionArgs, builder, moduleTranslation, allocaIP,
3740 reductionDecls, privateReductionVariables, reductionVariableMap,
3741 deferredStores, isByRef)))
3742 return llvm::make_error<PreviouslyReportedError>();
3743
3744 assert(afterAllocas.get()->getSinglePredecessor());
3745 builder.restoreIP(codeGenIP);
3746
3747 if (handleError(
3748 initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3749 *opInst)
3750 .failed())
3751 return llvm::make_error<PreviouslyReportedError>();
3752
3753 if (failed(copyFirstPrivateVars(
3754 opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
3755 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
3756 opInst.getPrivateNeedsBarrier())))
3757 return llvm::make_error<PreviouslyReportedError>();
3758
3759 if (failed(
3760 initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
3761 afterAllocas.get()->getSinglePredecessor(),
3762 reductionDecls, privateReductionVariables,
3763 reductionVariableMap, isByRef, deferredStores)))
3764 return llvm::make_error<PreviouslyReportedError>();
3765
3766 // Save the alloca insertion point on ModuleTranslation stack for use in
3767 // nested regions.
3769 moduleTranslation, allocaIP);
3770
3771 // ParallelOp has only one region associated with it.
3773 opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
3774 if (!regionBlock)
3775 return regionBlock.takeError();
3776
3777 // Process the reductions if required.
3778 if (opInst.getNumReductionVars() > 0) {
3779 // Collect reduction info
3781 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
3783 owningReductionGenRefDataPtrGens;
3785 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
3786 owningReductionGens, owningAtomicReductionGens,
3787 owningReductionGenRefDataPtrGens,
3788 privateReductionVariables, reductionInfos, isByRef);
3789
3790 // Move to region cont block
3791 builder.SetInsertPoint((*regionBlock)->getTerminator());
3792
3793 // Generate reductions from info
3794 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
3795 builder.SetInsertPoint(tempTerminator);
3796
3797 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
3798 ompBuilder->createReductions(
3799 builder.saveIP(), allocaIP, reductionInfos, isByRef,
3800 /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
3801 if (!contInsertPoint)
3802 return contInsertPoint.takeError();
3803
3804 if (!contInsertPoint->getBlock())
3805 return llvm::make_error<PreviouslyReportedError>();
3806
3807 tempTerminator->eraseFromParent();
3808 builder.restoreIP(*contInsertPoint);
3809 }
3810
3811 return llvm::Error::success();
3812 };
3813
3814 auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
3815 llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
3816 // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
3817 // bodyGenCB.
3818 replVal = &val;
3819 return codeGenIP;
3820 };
3821
3822 // TODO: Perform finalization actions for variables. This has to be
3823 // called for variables which have destructors/finalizers.
3824 auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
3825 InsertPointTy oldIP = builder.saveIP();
3826 builder.restoreIP(codeGenIP);
3827
3828 // if the reduction has a cleanup region, inline it here to finalize the
3829 // reduction variables
3830 SmallVector<Region *> reductionCleanupRegions;
3831 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
3832 [](omp::DeclareReductionOp reductionDecl) {
3833 return &reductionDecl.getCleanupRegion();
3834 });
3835 if (failed(inlineOmpRegionCleanup(
3836 reductionCleanupRegions, privateReductionVariables,
3837 moduleTranslation, builder, "omp.reduction.cleanup")))
3838 return llvm::createStringError(
3839 "failed to inline `cleanup` region of `omp.declare_reduction`");
3840
3841 if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
3842 privateVarsInfo.llvmVars,
3843 privateVarsInfo.privatizers)))
3844 return llvm::make_error<PreviouslyReportedError>();
3845
3846 // If we could be performing cancellation, add the cancellation barrier on
3847 // the way out of the outlined region.
3848 if (isCancellable) {
3849 auto IPOrErr = ompBuilder->createBarrier(
3850 llvm::OpenMPIRBuilder::LocationDescription(builder),
3851 llvm::omp::Directive::OMPD_unknown,
3852 /* ForceSimpleCall */ false,
3853 /* CheckCancelFlag */ false);
3854 if (!IPOrErr)
3855 return IPOrErr.takeError();
3856 }
3857
3858 builder.restoreIP(oldIP);
3859 return llvm::Error::success();
3860 };
3861
3862 llvm::Value *ifCond = nullptr;
3863 if (auto ifVar = opInst.getIfExpr())
3864 ifCond = moduleTranslation.lookupValue(ifVar);
3865 llvm::Value *numThreads = nullptr;
3866 if (!opInst.getNumThreadsVars().empty())
3867 numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
3868 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
3869 if (auto bind = opInst.getProcBindKind())
3870 pbKind = getProcBindKind(*bind);
3871
3872 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3873 findAllocaInsertPoint(builder, moduleTranslation);
3874 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3875
3876 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3877 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
3878 ifCond, numThreads, pbKind, isCancellable);
3879
3880 if (failed(handleError(afterIP, *opInst)))
3881 return failure();
3882
3883 builder.restoreIP(*afterIP);
3884 return success();
3885}
3886
3887/// Convert Order attribute to llvm::omp::OrderKind.
3888static llvm::omp::OrderKind
3889convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
3890 if (!o)
3891 return llvm::omp::OrderKind::OMP_ORDER_unknown;
3892 switch (*o) {
3893 case omp::ClauseOrderKind::Concurrent:
3894 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
3895 }
3896 llvm_unreachable("Unknown ClauseOrderKind kind");
3897}
3898
3899/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
3900static LogicalResult
3901convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
3902 LLVM::ModuleTranslation &moduleTranslation) {
3903 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3904 auto simdOp = cast<omp::SimdOp>(opInst);
3905
3906 if (failed(checkImplementationStatus(opInst)))
3907 return failure();
3908
3909 PrivateVarsInfo privateVarsInfo(simdOp);
3910
3911 MutableArrayRef<BlockArgument> reductionArgs =
3912 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
3913 DenseMap<Value, llvm::Value *> reductionVariableMap;
3914 SmallVector<llvm::Value *> privateReductionVariables(
3915 simdOp.getNumReductionVars());
3916 SmallVector<DeferredStore> deferredStores;
3918 collectReductionDecls(simdOp, reductionDecls);
3919 llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
3920 assert(isByRef.size() == simdOp.getNumReductionVars());
3921
3922 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3923 findAllocaInsertPoint(builder, moduleTranslation);
3924
3926 builder, moduleTranslation, privateVarsInfo, allocaIP);
3927 if (handleError(afterAllocas, opInst).failed())
3928 return failure();
3929
3930 // Initialize linear variables and linear step
3931 LinearClauseProcessor linearClauseProcessor;
3932
3933 if (!simdOp.getLinearVars().empty()) {
3934 auto linearVarTypes = simdOp.getLinearVarTypes().value();
3935 for (mlir::Attribute linearVarType : linearVarTypes)
3936 linearClauseProcessor.registerType(moduleTranslation, linearVarType);
3937 for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) {
3938 bool isImplicit = false;
3939 for (auto [mlirPrivVar, llvmPrivateVar] : llvm::zip_equal(
3940 privateVarsInfo.mlirVars, privateVarsInfo.llvmVars)) {
3941 // If the linear variable is implicit, reuse the already
3942 // existing llvm::Value
3943 if (linearVar == mlirPrivVar) {
3944 isImplicit = true;
3945 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
3946 llvmPrivateVar, idx);
3947 break;
3948 }
3949 }
3950
3951 if (!isImplicit)
3952 linearClauseProcessor.createLinearVar(
3953 builder, moduleTranslation,
3954 moduleTranslation.lookupValue(linearVar), idx);
3955 }
3956 for (mlir::Value linearStep : simdOp.getLinearStepVars())
3957 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
3958 }
3959
3960 if (failed(allocReductionVars(simdOp, reductionArgs, builder,
3961 moduleTranslation, allocaIP, reductionDecls,
3962 privateReductionVariables, reductionVariableMap,
3963 deferredStores, isByRef)))
3964 return failure();
3965
3966 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
3967 opInst)
3968 .failed())
3969 return failure();
3970
3971 // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
3972 // SIMD.
3973
3974 assert(afterAllocas.get()->getSinglePredecessor());
3975 if (failed(initReductionVars(simdOp, reductionArgs, builder,
3976 moduleTranslation,
3977 afterAllocas.get()->getSinglePredecessor(),
3978 reductionDecls, privateReductionVariables,
3979 reductionVariableMap, isByRef, deferredStores)))
3980 return failure();
3981
3982 llvm::ConstantInt *simdlen = nullptr;
3983 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
3984 simdlen = builder.getInt64(simdlenVar.value());
3985
3986 llvm::ConstantInt *safelen = nullptr;
3987 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
3988 safelen = builder.getInt64(safelenVar.value());
3989
3990 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
3991 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
3992
3993 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
3994 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
3995 mlir::OperandRange operands = simdOp.getAlignedVars();
3996 for (size_t i = 0; i < operands.size(); ++i) {
3997 llvm::Value *alignment = nullptr;
3998 llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
3999 llvm::Type *ty = llvmVal->getType();
4000
4001 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
4002 alignment = builder.getInt64(intAttr.getInt());
4003 assert(ty->isPointerTy() && "Invalid type for aligned variable");
4004 assert(alignment && "Invalid alignment value");
4005
4006 // Check if the alignment value is not a power of 2. If so, skip emitting
4007 // alignment.
4008 if (!intAttr.getValue().isPowerOf2())
4009 continue;
4010
4011 auto curInsert = builder.saveIP();
4012 builder.SetInsertPoint(sourceBlock);
4013 llvmVal = builder.CreateLoad(ty, llvmVal);
4014 builder.restoreIP(curInsert);
4015 alignedVars[llvmVal] = alignment;
4016 }
4017
4019 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
4020
4021 if (failed(handleError(regionBlock, opInst)))
4022 return failure();
4023
4024 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
4025 // Emit Initialization for linear variables
4026 if (simdOp.getLinearVars().size()) {
4027 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
4028 loopInfo->getPreheader());
4029
4030 linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
4031 loopInfo->getIndVar());
4032 }
4033 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4034
4035 ompBuilder->applySimd(loopInfo, alignedVars,
4036 simdOp.getIfExpr()
4037 ? moduleTranslation.lookupValue(simdOp.getIfExpr())
4038 : nullptr,
4039 order, simdlen, safelen);
4040
4041 linearClauseProcessor.emitStoresForLinearVar(builder);
4042
4043 // Check if this SIMD loop contains ordered regions
4044 bool hasOrderedRegions = false;
4045 simdOp.getRegion().walk([&](omp::OrderedRegionOp orderedOp) {
4046 hasOrderedRegions = true;
4047 return WalkResult::interrupt();
4048 });
4049
4050 for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) {
4051 linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
4052 index);
4053 if (hasOrderedRegions) {
4054 // Also rewrite uses in ordered regions so they read the current value
4055 linearClauseProcessor.rewriteInPlace(builder, "omp.ordered.region",
4056 index);
4057 // Also rewrite uses in finalize blocks (code after ordered regions)
4058 linearClauseProcessor.rewriteInPlace(builder, "omp_region.finalize",
4059 index);
4060 }
4061 }
4062
4063 // We now need to reduce the per-simd-lane reduction variable into the
4064 // original variable. This works a bit differently to other reductions (e.g.
4065 // wsloop) because we don't need to call into the OpenMP runtime to handle
4066 // threads: everything happened in this one thread.
4067 for (auto [i, tuple] : llvm::enumerate(
4068 llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
4069 privateReductionVariables))) {
4070 auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
4071
4072 OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
4073 llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
4074 llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());
4075
4076 // We have one less load for by-ref case because that load is now inside of
4077 // the reduction region.
4078 llvm::Value *redValue = originalVariable;
4079 if (!byRef)
4080 redValue =
4081 builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
4082 llvm::Value *privateRedValue = builder.CreateLoad(
4083 reductionType, privateReductionVar, "red.private.value." + Twine(i));
4084 llvm::Value *reduced;
4085
4086 auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
4087 if (failed(handleError(res, opInst)))
4088 return failure();
4089 builder.restoreIP(res.get());
4090
4091 // For by-ref case, the store is inside of the reduction region.
4092 if (!byRef)
4093 builder.CreateStore(reduced, originalVariable);
4094 }
4095
4096 // After the construct, deallocate private reduction variables.
4097 SmallVector<Region *> reductionRegions;
4098 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
4099 [](omp::DeclareReductionOp reductionDecl) {
4100 return &reductionDecl.getCleanupRegion();
4101 });
4102 if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
4103 moduleTranslation, builder,
4104 "omp.reduction.cleanup")))
4105 return failure();
4106
4107 return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
4108 privateVarsInfo.llvmVars,
4109 privateVarsInfo.privatizers);
4110}
4111
4112/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
4113static LogicalResult
4114convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
4115 LLVM::ModuleTranslation &moduleTranslation) {
4116 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4117 auto loopOp = cast<omp::LoopNestOp>(opInst);
4118
4119 if (failed(checkImplementationStatus(opInst)))
4120 return failure();
4121
4122 // Set up the source location value for OpenMP runtime.
4123 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4124
4125 // Generator of the canonical loop body.
4128 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
4129 llvm::Value *iv) -> llvm::Error {
4130 // Make sure further conversions know about the induction variable.
4131 moduleTranslation.mapValue(
4132 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
4133
4134 // Capture the body insertion point for use in nested loops. BodyIP of the
4135 // CanonicalLoopInfo always points to the beginning of the entry block of
4136 // the body.
4137 bodyInsertPoints.push_back(ip);
4138
4139 if (loopInfos.size() != loopOp.getNumLoops() - 1)
4140 return llvm::Error::success();
4141
4142 // Convert the body of the loop.
4143 builder.restoreIP(ip);
4145 loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
4146 if (!regionBlock)
4147 return regionBlock.takeError();
4148
4149 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4150 return llvm::Error::success();
4151 };
4152
4153 // Delegate actual loop construction to the OpenMP IRBuilder.
4154 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
4155 // loop, i.e. it has a positive step, uses signed integer semantics.
4156 // Reconsider this code when the nested loop operation clearly supports more
4157 // cases.
4158 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
4159 llvm::Value *lowerBound =
4160 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
4161 llvm::Value *upperBound =
4162 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
4163 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
4164
4165 // Make sure loop trip count are emitted in the preheader of the outermost
4166 // loop at the latest so that they are all available for the new collapsed
4167 // loop will be created below.
4168 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
4169 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
4170 if (i != 0) {
4171 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
4172 ompLoc.DL);
4173 computeIP = loopInfos.front()->getPreheaderIP();
4174 }
4175
4177 ompBuilder->createCanonicalLoop(
4178 loc, bodyGen, lowerBound, upperBound, step,
4179 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
4180
4181 if (failed(handleError(loopResult, *loopOp)))
4182 return failure();
4183
4184 loopInfos.push_back(*loopResult);
4185 }
4186
4187 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
4188 loopInfos.front()->getAfterIP();
4189
4190 // Do tiling.
4191 if (const auto &tiles = loopOp.getTileSizes()) {
4192 llvm::Type *ivType = loopInfos.front()->getIndVarType();
4194
4195 for (auto tile : tiles.value()) {
4196 llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
4197 tileSizes.push_back(tileVal);
4198 }
4199
4200 std::vector<llvm::CanonicalLoopInfo *> newLoops =
4201 ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
4202
4203 // Update afterIP to get the correct insertion point after
4204 // tiling.
4205 llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
4206 llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
4207 afterIP = {afterAfterBB, afterAfterBB->begin()};
4208
4209 // Update the loop infos.
4210 loopInfos.clear();
4211 for (const auto &newLoop : newLoops)
4212 loopInfos.push_back(newLoop);
4213 } // Tiling done.
4214
4215 // Do collapse.
4216 const auto &numCollapse = loopOp.getCollapseNumLoops();
4218 loopInfos.begin(), loopInfos.begin() + (numCollapse));
4219
4220 auto newTopLoopInfo =
4221 ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
4222
4223 assert(newTopLoopInfo && "New top loop information is missing");
4224 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
4225 [&](OpenMPLoopInfoStackFrame &frame) {
4226 frame.loopInfo = newTopLoopInfo;
4227 return WalkResult::interrupt();
4228 });
4229
4230 // Continue building IR after the loop. Note that the LoopInfo returned by
4231 // `collapseLoops` points inside the outermost loop and is intended for
4232 // potential further loop transformations. Use the insertion point stored
4233 // before collapsing loops instead.
4234 builder.restoreIP(afterIP);
4235 return success();
4236}
4237
4238/// Convert an omp.canonical_loop to LLVM-IR
4239static LogicalResult
4240convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
4241 LLVM::ModuleTranslation &moduleTranslation) {
4242 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4243
4244 llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
4245 Value loopIV = op.getInductionVar();
4246 Value loopTC = op.getTripCount();
4247
4248 llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
4249
4251 ompBuilder->createCanonicalLoop(
4252 loopLoc,
4253 [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
4254 // Register the mapping of MLIR induction variable to LLVM-IR
4255 // induction variable
4256 moduleTranslation.mapValue(loopIV, llvmIV);
4257
4258 builder.restoreIP(ip);
4260 convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
4261 moduleTranslation);
4262
4263 return bodyGenStatus.takeError();
4264 },
4265 llvmTC, "omp.loop");
4266 if (!llvmOrError)
4267 return op.emitError(llvm::toString(llvmOrError.takeError()));
4268
4269 llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
4270 llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
4271 builder.restoreIP(afterIP);
4272
4273 // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
4274 if (Value cli = op.getCli())
4275 moduleTranslation.mapOmpLoop(cli, llvmCLI);
4276
4277 return success();
4278}
4279
4280/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
4281/// OpenMPIRBuilder.
4282static LogicalResult
4283applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
4284 LLVM::ModuleTranslation &moduleTranslation) {
4285 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4286
4287 Value applyee = op.getApplyee();
4288 assert(applyee && "Loop to apply unrolling on required");
4289
4290 llvm::CanonicalLoopInfo *consBuilderCLI =
4291 moduleTranslation.lookupOMPLoop(applyee);
4292 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4293 ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
4294
4295 moduleTranslation.invalidateOmpLoop(applyee);
4296 return success();
4297}
4298
4299/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
4300/// OpenMPIRBuilder.
4301static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
4302 LLVM::ModuleTranslation &moduleTranslation) {
4303 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4304 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4305
4307 SmallVector<llvm::Value *> translatedSizes;
4308
4309 for (Value size : op.getSizes()) {
4310 llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
4311 assert(translatedSize &&
4312 "sizes clause arguments must already be translated");
4313 translatedSizes.push_back(translatedSize);
4314 }
4315
4316 for (Value applyee : op.getApplyees()) {
4317 llvm::CanonicalLoopInfo *consBuilderCLI =
4318 moduleTranslation.lookupOMPLoop(applyee);
4319 assert(applyee && "Canonical loop must already been translated");
4320 translatedLoops.push_back(consBuilderCLI);
4321 }
4322
4323 auto generatedLoops =
4324 ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
4325 if (!op.getGeneratees().empty()) {
4326 for (auto [mlirLoop, genLoop] :
4327 zip_equal(op.getGeneratees(), generatedLoops))
4328 moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
4329 }
4330
4331 // CLIs can only be consumed once
4332 for (Value applyee : op.getApplyees())
4333 moduleTranslation.invalidateOmpLoop(applyee);
4334
4335 return success();
4336}
4337
4338/// Apply a `#pragma omp fuse` / `!$omp fuse` transformation using the
4339/// OpenMPIRBuilder.
4340static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder,
4341 LLVM::ModuleTranslation &moduleTranslation) {
4342 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4343 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
4344
4345 // Select what CLIs are going to be fused
4346 SmallVector<llvm::CanonicalLoopInfo *> beforeFuse, toFuse, afterFuse;
4347 for (size_t i = 0; i < op.getApplyees().size(); i++) {
4348 Value applyee = op.getApplyees()[i];
4349 llvm::CanonicalLoopInfo *consBuilderCLI =
4350 moduleTranslation.lookupOMPLoop(applyee);
4351 assert(applyee && "Canonical loop must already been translated");
4352 if (op.getFirst().has_value() && i < op.getFirst().value() - 1)
4353 beforeFuse.push_back(consBuilderCLI);
4354 else if (op.getCount().has_value() &&
4355 i >= op.getFirst().value() + op.getCount().value() - 1)
4356 afterFuse.push_back(consBuilderCLI);
4357 else
4358 toFuse.push_back(consBuilderCLI);
4359 }
4360 assert(
4361 (op.getGeneratees().empty() ||
4362 beforeFuse.size() + afterFuse.size() + 1 == op.getGeneratees().size()) &&
4363 "Wrong number of generatees");
4364
4365 // do the fuse
4366 auto generatedLoop = ompBuilder->fuseLoops(loc.DL, toFuse);
4367 if (!op.getGeneratees().empty()) {
4368 size_t i = 0;
4369 for (; i < beforeFuse.size(); i++)
4370 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], beforeFuse[i]);
4371 moduleTranslation.mapOmpLoop(op.getGeneratees()[i++], generatedLoop);
4372 for (; i < afterFuse.size(); i++)
4373 moduleTranslation.mapOmpLoop(op.getGeneratees()[i], afterFuse[i]);
4374 }
4375
4376 // CLIs can only be consumed once
4377 for (Value applyee : op.getApplyees())
4378 moduleTranslation.invalidateOmpLoop(applyee);
4379
4380 return success();
4381}
4382
4383/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
4384static llvm::AtomicOrdering
4385convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
4386 if (!ao)
4387 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
4388
4389 switch (*ao) {
4390 case omp::ClauseMemoryOrderKind::Seq_cst:
4391 return llvm::AtomicOrdering::SequentiallyConsistent;
4392 case omp::ClauseMemoryOrderKind::Acq_rel:
4393 return llvm::AtomicOrdering::AcquireRelease;
4394 case omp::ClauseMemoryOrderKind::Acquire:
4395 return llvm::AtomicOrdering::Acquire;
4396 case omp::ClauseMemoryOrderKind::Release:
4397 return llvm::AtomicOrdering::Release;
4398 case omp::ClauseMemoryOrderKind::Relaxed:
4399 return llvm::AtomicOrdering::Monotonic;
4400 }
4401 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
4402}
4403
4404/// Convert omp.atomic.read operation to LLVM IR.
4405static LogicalResult
4406convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
4407 LLVM::ModuleTranslation &moduleTranslation) {
4408 auto readOp = cast<omp::AtomicReadOp>(opInst);
4409 if (failed(checkImplementationStatus(opInst)))
4410 return failure();
4411
4412 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4413 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4414 findAllocaInsertPoint(builder, moduleTranslation);
4415
4416 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4417
4418 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
4419 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
4420 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
4421
4422 llvm::Type *elementType =
4423 moduleTranslation.convertType(readOp.getElementType());
4424
4425 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
4426 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
4427 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO, allocaIP));
4428 return success();
4429}
4430
4431/// Converts an omp.atomic.write operation to LLVM IR.
4432static LogicalResult
4433convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
4434 LLVM::ModuleTranslation &moduleTranslation) {
4435 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
4436 if (failed(checkImplementationStatus(opInst)))
4437 return failure();
4438
4439 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4440 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4441 findAllocaInsertPoint(builder, moduleTranslation);
4442
4443 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4444 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
4445 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
4446 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
4447 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
4448 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
4449 /*isVolatile=*/false};
4450 builder.restoreIP(
4451 ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
4452 return success();
4453}
4454
4455/// Converts an LLVM dialect binary operation to the corresponding enum value
4456/// for `atomicrmw` supported binary operation.
4457static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
4459 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
4460 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
4461 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
4462 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
4463 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
4464 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
4465 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
4466 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
4467 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
4468 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
4469}
4470
4471static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
4472 bool &isIgnoreDenormalMode,
4473 bool &isFineGrainedMemory,
4474 bool &isRemoteMemory) {
4475 isIgnoreDenormalMode = false;
4476 isFineGrainedMemory = false;
4477 isRemoteMemory = false;
4478 if (atomicUpdateOp &&
4479 atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
4480 mlir::omp::AtomicControlAttr atomicControlAttr =
4481 atomicUpdateOp.getAtomicControlAttr();
4482 isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
4483 isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
4484 isRemoteMemory = atomicControlAttr.getRemoteMemory();
4485 }
4486}
4487
4488/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
4489static LogicalResult
4490convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
4491 llvm::IRBuilderBase &builder,
4492 LLVM::ModuleTranslation &moduleTranslation) {
4493 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4494 if (failed(checkImplementationStatus(*opInst)))
4495 return failure();
4496
4497 // Convert values and types.
4498 auto &innerOpList = opInst.getRegion().front().getOperations();
4499 bool isXBinopExpr{false};
4500 llvm::AtomicRMWInst::BinOp binop;
4501 mlir::Value mlirExpr;
4502 llvm::Value *llvmExpr = nullptr;
4503 llvm::Value *llvmX = nullptr;
4504 llvm::Type *llvmXElementType = nullptr;
4505 if (innerOpList.size() == 2) {
4506 // The two operations here are the update and the terminator.
4507 // Since we can identify the update operation, there is a possibility
4508 // that we can generate the atomicrmw instruction.
4509 mlir::Operation &innerOp = *opInst.getRegion().front().begin();
4510 if (!llvm::is_contained(innerOp.getOperands(),
4511 opInst.getRegion().getArgument(0))) {
4512 return opInst.emitError("no atomic update operation with region argument"
4513 " as operand found inside atomic.update region");
4514 }
4515 binop = convertBinOpToAtomic(innerOp);
4516 isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
4517 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4518 llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4519 } else {
4520 // Since the update region includes more than one operation
4521 // we will resort to generating a cmpxchg loop.
4522 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4523 }
4524 llvmX = moduleTranslation.lookupValue(opInst.getX());
4525 llvmXElementType = moduleTranslation.convertType(
4526 opInst.getRegion().getArgument(0).getType());
4527 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4528 /*isSigned=*/false,
4529 /*isVolatile=*/false};
4530
4531 llvm::AtomicOrdering atomicOrdering =
4532 convertAtomicOrdering(opInst.getMemoryOrder());
4533
4534 // Generate update code.
4535 auto updateFn =
4536 [&opInst, &moduleTranslation](
4537 llvm::Value *atomicx,
4538 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4539 Block &bb = *opInst.getRegion().begin();
4540 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
4541 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4542 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4543 return llvm::make_error<PreviouslyReportedError>();
4544
4545 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4546 assert(yieldop && yieldop.getResults().size() == 1 &&
4547 "terminator must be omp.yield op and it must have exactly one "
4548 "argument");
4549 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4550 };
4551
4552 bool isIgnoreDenormalMode;
4553 bool isFineGrainedMemory;
4554 bool isRemoteMemory;
4555 extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
4556 isRemoteMemory);
4557 // Handle ambiguous alloca, if any.
4558 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
4559 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4560 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4561 ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
4562 atomicOrdering, binop, updateFn,
4563 isXBinopExpr, isIgnoreDenormalMode,
4564 isFineGrainedMemory, isRemoteMemory);
4565
4566 if (failed(handleError(afterIP, *opInst)))
4567 return failure();
4568
4569 builder.restoreIP(*afterIP);
4570 return success();
4571}
4572
4573static LogicalResult
4574convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
4575 llvm::IRBuilderBase &builder,
4576 LLVM::ModuleTranslation &moduleTranslation) {
4577 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4578 if (failed(checkImplementationStatus(*atomicCaptureOp)))
4579 return failure();
4580
4581 mlir::Value mlirExpr;
4582 bool isXBinopExpr = false, isPostfixUpdate = false;
4583 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4584
4585 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
4586 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
4587
4588 assert((atomicUpdateOp || atomicWriteOp) &&
4589 "internal op must be an atomic.update or atomic.write op");
4590
4591 if (atomicWriteOp) {
4592 isPostfixUpdate = true;
4593 mlirExpr = atomicWriteOp.getExpr();
4594 } else {
4595 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
4596 atomicCaptureOp.getAtomicUpdateOp().getOperation();
4597 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
4598 // Find the binary update operation that uses the region argument
4599 // and get the expression to update
4600 if (innerOpList.size() == 2) {
4601 mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
4602 if (!llvm::is_contained(innerOp.getOperands(),
4603 atomicUpdateOp.getRegion().getArgument(0))) {
4604 return atomicUpdateOp.emitError(
4605 "no atomic update operation with region argument"
4606 " as operand found inside atomic.update region");
4607 }
4608 binop = convertBinOpToAtomic(innerOp);
4609 isXBinopExpr =
4610 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
4611 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
4612 } else {
4613 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
4614 }
4615 }
4616
4617 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
4618 llvm::Value *llvmX =
4619 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
4620 llvm::Value *llvmV =
4621 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
4622 llvm::Type *llvmXElementType = moduleTranslation.convertType(
4623 atomicCaptureOp.getAtomicReadOp().getElementType());
4624 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
4625 /*isSigned=*/false,
4626 /*isVolatile=*/false};
4627 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
4628 /*isSigned=*/false,
4629 /*isVolatile=*/false};
4630
4631 llvm::AtomicOrdering atomicOrdering =
4632 convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
4633
4634 auto updateFn =
4635 [&](llvm::Value *atomicx,
4636 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
4637 if (atomicWriteOp)
4638 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
4639 Block &bb = *atomicUpdateOp.getRegion().begin();
4640 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
4641 atomicx);
4642 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
4643 if (failed(moduleTranslation.convertBlock(bb, true, builder)))
4644 return llvm::make_error<PreviouslyReportedError>();
4645
4646 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
4647 assert(yieldop && yieldop.getResults().size() == 1 &&
4648 "terminator must be omp.yield op and it must have exactly one "
4649 "argument");
4650 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
4651 };
4652
4653 bool isIgnoreDenormalMode;
4654 bool isFineGrainedMemory;
4655 bool isRemoteMemory;
4656 extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
4657 isFineGrainedMemory, isRemoteMemory);
4658 // Handle ambiguous alloca, if any.
4659 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
4660 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4661 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4662 ompBuilder->createAtomicCapture(
4663 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
4664 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
4665 isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
4666
4667 if (failed(handleError(afterIP, *atomicCaptureOp)))
4668 return failure();
4669
4670 builder.restoreIP(*afterIP);
4671 return success();
4672}
4673
4674static llvm::omp::Directive convertCancellationConstructType(
4675 omp::ClauseCancellationConstructType directive) {
4676 switch (directive) {
4677 case omp::ClauseCancellationConstructType::Loop:
4678 return llvm::omp::Directive::OMPD_for;
4679 case omp::ClauseCancellationConstructType::Parallel:
4680 return llvm::omp::Directive::OMPD_parallel;
4681 case omp::ClauseCancellationConstructType::Sections:
4682 return llvm::omp::Directive::OMPD_sections;
4683 case omp::ClauseCancellationConstructType::Taskgroup:
4684 return llvm::omp::Directive::OMPD_taskgroup;
4685 }
4686 llvm_unreachable("Unhandled cancellation construct type");
4687}
4688
4689static LogicalResult
4690convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
4691 LLVM::ModuleTranslation &moduleTranslation) {
4692 if (failed(checkImplementationStatus(*op.getOperation())))
4693 return failure();
4694
4695 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4696 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4697
4698 llvm::Value *ifCond = nullptr;
4699 if (Value ifVar = op.getIfExpr())
4700 ifCond = moduleTranslation.lookupValue(ifVar);
4701
4702 llvm::omp::Directive cancelledDirective =
4703 convertCancellationConstructType(op.getCancelDirective());
4704
4705 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4706 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
4707
4708 if (failed(handleError(afterIP, *op.getOperation())))
4709 return failure();
4710
4711 builder.restoreIP(afterIP.get());
4712
4713 return success();
4714}
4715
4716static LogicalResult
4717convertOmpCancellationPoint(omp::CancellationPointOp op,
4718 llvm::IRBuilderBase &builder,
4719 LLVM::ModuleTranslation &moduleTranslation) {
4720 if (failed(checkImplementationStatus(*op.getOperation())))
4721 return failure();
4722
4723 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4724 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4725
4726 llvm::omp::Directive cancelledDirective =
4727 convertCancellationConstructType(op.getCancelDirective());
4728
4729 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4730 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
4731
4732 if (failed(handleError(afterIP, *op.getOperation())))
4733 return failure();
4734
4735 builder.restoreIP(afterIP.get());
4736
4737 return success();
4738}
4739
4740/// Converts an OpenMP Threadprivate operation into LLVM IR using
4741/// OpenMPIRBuilder.
4742static LogicalResult
4743convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
4744 LLVM::ModuleTranslation &moduleTranslation) {
4745 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4746 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4747 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
4748
4749 if (failed(checkImplementationStatus(opInst)))
4750 return failure();
4751
4752 Value symAddr = threadprivateOp.getSymAddr();
4753 auto *symOp = symAddr.getDefiningOp();
4754
4755 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
4756 symOp = asCast.getOperand().getDefiningOp();
4757
4758 if (!isa<LLVM::AddressOfOp>(symOp))
4759 return opInst.emitError("Addressing symbol not found");
4760 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
4761
4762 LLVM::GlobalOp global =
4763 addressOfOp.getGlobal(moduleTranslation.symbolTable());
4764 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
4765 llvm::Type *type = globalValue->getValueType();
4766 llvm::TypeSize typeSize =
4767 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
4768 type);
4769 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
4770 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
4771 ompLoc, globalValue, size, global.getSymName() + ".cache");
4772 moduleTranslation.mapValue(opInst.getResult(0), callInst);
4773
4774 return success();
4775}
4776
4777static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
4778convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
4779 switch (deviceClause) {
4780 case mlir::omp::DeclareTargetDeviceType::host:
4781 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
4782 break;
4783 case mlir::omp::DeclareTargetDeviceType::nohost:
4784 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
4785 break;
4786 case mlir::omp::DeclareTargetDeviceType::any:
4787 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
4788 break;
4789 }
4790 llvm_unreachable("unhandled device clause");
4791}
4792
4793static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
4795 mlir::omp::DeclareTargetCaptureClause captureClause) {
4796 switch (captureClause) {
4797 case mlir::omp::DeclareTargetCaptureClause::to:
4798 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
4799 case mlir::omp::DeclareTargetCaptureClause::link:
4800 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
4801 case mlir::omp::DeclareTargetCaptureClause::enter:
4802 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
4803 case mlir::omp::DeclareTargetCaptureClause::none:
4804 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
4805 }
4806 llvm_unreachable("unhandled capture clause");
4807}
4808
4810 Operation *op = value.getDefiningOp();
4811 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4812 op = addrCast->getOperand(0).getDefiningOp();
4813 if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
4814 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
4815 return modOp.lookupSymbol(addressOfOp.getGlobalName());
4816 }
4817 return nullptr;
4818}
4819
4821 while (Operation *op = value.getDefiningOp()) {
4822 if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
4823 value = addrCast.getOperand();
4824 // Traces through hlfir.declare, fir.declare to reach the base address and
4825 // use for type lookup.
4826 else if (op->getName().getIdentifier() &&
4827 (op->getName().getIdentifier().str() == "hlfir.declare" ||
4828 op->getName().getIdentifier().str() == "fir.declare")) {
4829 if (op->getNumOperands() > 0)
4830 value = op->getOperand(0);
4831 else
4832 break;
4833 } else {
4834 break;
4835 }
4836 }
4837 return value;
4838}
4839
4840static llvm::SmallString<64>
4841getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
4842 llvm::OpenMPIRBuilder &ompBuilder) {
4843 llvm::SmallString<64> suffix;
4844 llvm::raw_svector_ostream os(suffix);
4845 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
4846 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
4847 auto fileInfoCallBack = [&loc]() {
4848 return std::pair<std::string, uint64_t>(
4849 llvm::StringRef(loc.getFilename()), loc.getLine());
4850 };
4851
4852 auto vfs = llvm::vfs::getRealFileSystem();
4853 os << llvm::format(
4854 "_%x",
4855 ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack, *vfs).FileID);
4856 }
4857 os << "_decl_tgt_ref_ptr";
4858
4859 return suffix;
4860}
4861
4862static bool isDeclareTargetLink(Value value) {
4863 if (auto declareTargetGlobal =
4864 dyn_cast_if_present<omp::DeclareTargetInterface>(
4865 getGlobalOpFromValue(value)))
4866 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4867 omp::DeclareTargetCaptureClause::link)
4868 return true;
4869 return false;
4870}
4871
4872static bool isDeclareTargetTo(Value value) {
4873 if (auto declareTargetGlobal =
4874 dyn_cast_if_present<omp::DeclareTargetInterface>(
4875 getGlobalOpFromValue(value)))
4876 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4877 omp::DeclareTargetCaptureClause::to ||
4878 declareTargetGlobal.getDeclareTargetCaptureClause() ==
4879 omp::DeclareTargetCaptureClause::enter)
4880 return true;
4881 return false;
4882}
4883
4884// Returns the reference pointer generated by the lowering of the declare
4885// target operation in cases where the link clause is used or the to clause is
4886// used in USM mode.
4887static llvm::Value *
4889 LLVM::ModuleTranslation &moduleTranslation) {
4890 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4891 if (auto gOp =
4892 dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
4893 if (auto declareTargetGlobal =
4894 dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
4895 // In this case, we must utilise the reference pointer generated by
4896 // the declare target operation, similar to Clang
4897 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
4898 omp::DeclareTargetCaptureClause::link) ||
4899 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
4900 omp::DeclareTargetCaptureClause::to &&
4901 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
4902 llvm::SmallString<64> suffix =
4903 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
4904
4905 if (gOp.getSymName().contains(suffix))
4906 return moduleTranslation.getLLVMModule()->getNamedValue(
4907 gOp.getSymName());
4908
4909 return moduleTranslation.getLLVMModule()->getNamedValue(
4910 (gOp.getSymName().str() + suffix.str()).str());
4911 }
4912 }
4913 }
4914 return nullptr;
4915}
4916
4917namespace {
4918// Append customMappers information to existing MapInfosTy
4919struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
4920 SmallVector<Operation *, 4> Mappers;
4921
4922 /// Append arrays in \a CurInfo.
4923 void append(MapInfosTy &curInfo) {
4924 Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
4925 llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
4926 }
4927};
4928// A small helper structure to contain data gathered
4929// for map lowering and coalese it into one area and
4930// avoiding extra computations such as searches in the
4931// llvm module for lowered mapped variables or checking
4932// if something is declare target (and retrieving the
4933// value) more than neccessary.
4934struct MapInfoData : MapInfosTy {
4935 llvm::SmallVector<bool, 4> IsDeclareTarget;
4936 llvm::SmallVector<bool, 4> IsAMember;
4937 // Identify if mapping was added by mapClause or use_device clauses.
4938 llvm::SmallVector<bool, 4> IsAMapping;
4939 llvm::SmallVector<mlir::Operation *, 4> MapClause;
4940 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
4941 // Stripped off array/pointer to get the underlying
4942 // element type
4943 llvm::SmallVector<llvm::Type *, 4> BaseType;
4944
4945 /// Append arrays in \a CurInfo.
4946 void append(MapInfoData &CurInfo) {
4947 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
4948 CurInfo.IsDeclareTarget.end());
4949 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
4950 OriginalValue.append(CurInfo.OriginalValue.begin(),
4951 CurInfo.OriginalValue.end());
4952 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
4953 MapInfosTy::append(CurInfo);
4954 }
4955};
4956
4957enum class TargetDirectiveEnumTy : uint32_t {
4958 None = 0,
4959 Target = 1,
4960 TargetData = 2,
4961 TargetEnterData = 3,
4962 TargetExitData = 4,
4963 TargetUpdate = 5
4964};
4965
4966static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
4967 return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
4968 .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
4969 .Case([](omp::TargetEnterDataOp) {
4970 return TargetDirectiveEnumTy::TargetEnterData;
4971 })
4972 .Case([&](omp::TargetExitDataOp) {
4973 return TargetDirectiveEnumTy::TargetExitData;
4974 })
4975 .Case([&](omp::TargetUpdateOp) {
4976 return TargetDirectiveEnumTy::TargetUpdate;
4977 })
4978 .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
4979 .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
4980}
4981
4982} // namespace
4983
4984static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
4985 DataLayout &dl) {
4986 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
4987 arrTy.getElementType()))
4988 return getArrayElementSizeInBits(nestedArrTy, dl);
4989 return dl.getTypeSizeInBits(arrTy.getElementType());
4990}
4991
4992// This function calculates the size to be offloaded for a specified type, given
4993// its associated map clause (which can contain bounds information which affects
4994// the total size), this size is calculated based on the underlying element type
4995// e.g. given a 1-D array of ints, we will calculate the size from the integer
4996// type * number of elements in the array. This size can be used in other
4997// calculations but is ultimately used as an argument to the OpenMP runtimes
4998// kernel argument structure which is generated through the combinedInfo data
4999// structures.
5000// This function is somewhat equivalent to Clang's getExprTypeSize inside of
5001// CGOpenMPRuntime.cpp.
5002static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
5003 Operation *clauseOp,
5004 llvm::Value *basePointer,
5005 llvm::Type *baseType,
5006 llvm::IRBuilderBase &builder,
5007 LLVM::ModuleTranslation &moduleTranslation) {
5008 if (auto memberClause =
5009 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
5010 // This calculates the size to transfer based on bounds and the underlying
5011 // element type, provided bounds have been specified (Fortran
5012 // pointers/allocatables/target and arrays that have sections specified fall
5013 // into this as well)
5014 if (!memberClause.getBounds().empty()) {
5015 llvm::Value *elementCount = builder.getInt64(1);
5016 for (auto bounds : memberClause.getBounds()) {
5017 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
5018 bounds.getDefiningOp())) {
5019 // The below calculation for the size to be mapped calculated from the
5020 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
5021 // multiply by the underlying element types byte size to get the full
5022 // size to be offloaded based on the bounds
5023 elementCount = builder.CreateMul(
5024 elementCount,
5025 builder.CreateAdd(
5026 builder.CreateSub(
5027 moduleTranslation.lookupValue(boundOp.getUpperBound()),
5028 moduleTranslation.lookupValue(boundOp.getLowerBound())),
5029 builder.getInt64(1)));
5030 }
5031 }
5032
5033 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
5034 // the size in inconsistent byte or bit format.
5035 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
5036 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
5037 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
5038
5039 // The size in bytes x number of elements, the sizeInBytes stored is
5040 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
5041 // size, so we do some on the fly runtime math to get the size in
5042 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
5043 // some adjustment for members with more complex types.
5044 return builder.CreateMul(elementCount,
5045 builder.getInt64(underlyingTypeSzInBits / 8));
5046 }
5047 }
5048
5049 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
5050}
5051
5052// Convert the MLIR map flag set to the runtime map flag set for embedding
5053// in LLVM-IR. This is important as the two bit-flag lists do not correspond
5054// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
5055// Certain flags are discarded here such as RefPtee and co.
5056static llvm::omp::OpenMPOffloadMappingFlags
5057convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
5058 auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
5059 return (mlirFlags & flag) == flag;
5060 };
5061 const bool hasExplicitMap =
5062 (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
5063 omp::ClauseMapFlags::none;
5064
5065 llvm::omp::OpenMPOffloadMappingFlags mapType =
5066 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5067
5068 if (mapTypeToBool(omp::ClauseMapFlags::to))
5069 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
5070
5071 if (mapTypeToBool(omp::ClauseMapFlags::from))
5072 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
5073
5074 if (mapTypeToBool(omp::ClauseMapFlags::always))
5075 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5076
5077 if (mapTypeToBool(omp::ClauseMapFlags::del))
5078 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
5079
5080 if (mapTypeToBool(omp::ClauseMapFlags::return_param))
5081 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5082
5083 if (mapTypeToBool(omp::ClauseMapFlags::priv))
5084 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
5085
5086 if (mapTypeToBool(omp::ClauseMapFlags::literal))
5087 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5088
5089 if (mapTypeToBool(omp::ClauseMapFlags::implicit))
5090 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
5091
5092 if (mapTypeToBool(omp::ClauseMapFlags::close))
5093 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5094
5095 if (mapTypeToBool(omp::ClauseMapFlags::present))
5096 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
5097
5098 if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
5099 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
5100
5101 if (mapTypeToBool(omp::ClauseMapFlags::attach))
5102 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
5103
5104 if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
5105 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5106 if (!hasExplicitMap)
5107 mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5108 }
5109
5110 return mapType;
5111}
5112
5114 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
5115 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
5116 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
5117 ArrayRef<Value> useDevAddrOperands = {},
5118 ArrayRef<Value> hasDevAddrOperands = {}) {
5119 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
5120 // Check if this is a member mapping and correctly assign that it is, if
5121 // it is a member of a larger object.
5122 // TODO: Need better handling of members, and distinguishing of members
5123 // that are implicitly allocated on device vs explicitly passed in as
5124 // arguments.
5125 // TODO: May require some further additions to support nested record
5126 // types, i.e. member maps that can have member maps.
5127 for (Value mapValue : mapVars) {
5128 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5129 for (auto member : map.getMembers())
5130 if (member == mapOp)
5131 return true;
5132 }
5133 return false;
5134 };
5135
5136 // Process MapOperands
5137 for (Value mapValue : mapVars) {
5138 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5139 Value offloadPtr =
5140 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5141 mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
5142 mapData.Pointers.push_back(mapData.OriginalValue.back());
5143
5144 if (llvm::Value *refPtr =
5145 getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
5146 mapData.IsDeclareTarget.push_back(true);
5147 mapData.BasePointers.push_back(refPtr);
5148 } else if (isDeclareTargetTo(offloadPtr)) {
5149 mapData.IsDeclareTarget.push_back(true);
5150 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5151 } else { // regular mapped variable
5152 mapData.IsDeclareTarget.push_back(false);
5153 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5154 }
5155
5156 mapData.BaseType.push_back(
5157 moduleTranslation.convertType(mapOp.getVarType()));
5158 mapData.Sizes.push_back(
5159 getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
5160 mapData.BaseType.back(), builder, moduleTranslation));
5161 mapData.MapClause.push_back(mapOp.getOperation());
5162 mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
5163 mapData.Names.push_back(LLVM::createMappingInformation(
5164 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5165 mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5166 if (mapOp.getMapperId())
5167 mapData.Mappers.push_back(
5169 mapOp, mapOp.getMapperIdAttr()));
5170 else
5171 mapData.Mappers.push_back(nullptr);
5172 mapData.IsAMapping.push_back(true);
5173 mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
5174 }
5175
5176 auto findMapInfo = [&mapData](llvm::Value *val,
5177 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5178 unsigned index = 0;
5179 bool found = false;
5180 for (llvm::Value *basePtr : mapData.OriginalValue) {
5181 if (basePtr == val && mapData.IsAMapping[index]) {
5182 found = true;
5183 mapData.Types[index] |=
5184 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
5185 mapData.DevicePointers[index] = devInfoTy;
5186 }
5187 index++;
5188 }
5189 return found;
5190 };
5191
5192 // Process useDevPtr(Addr)Operands
5193 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
5194 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
5195 for (Value mapValue : useDevOperands) {
5196 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5197 Value offloadPtr =
5198 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5199 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5200
5201 // Check if map info is already present for this entry.
5202 if (!findMapInfo(origValue, devInfoTy)) {
5203 mapData.OriginalValue.push_back(origValue);
5204 mapData.Pointers.push_back(mapData.OriginalValue.back());
5205 mapData.IsDeclareTarget.push_back(false);
5206 mapData.BasePointers.push_back(mapData.OriginalValue.back());
5207 mapData.BaseType.push_back(
5208 moduleTranslation.convertType(mapOp.getVarType()));
5209 mapData.Sizes.push_back(builder.getInt64(0));
5210 mapData.MapClause.push_back(mapOp.getOperation());
5211 mapData.Types.push_back(
5212 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
5213 mapData.Names.push_back(LLVM::createMappingInformation(
5214 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5215 mapData.DevicePointers.push_back(devInfoTy);
5216 mapData.Mappers.push_back(nullptr);
5217 mapData.IsAMapping.push_back(false);
5218 mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
5219 }
5220 }
5221 };
5222
5223 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5224 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
5225
5226 for (Value mapValue : hasDevAddrOperands) {
5227 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
5228 Value offloadPtr =
5229 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
5230 llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
5231 auto mapType = convertClauseMapFlags(mapOp.getMapType());
5232 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
5233 bool isDevicePtr =
5234 (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
5235 omp::ClauseMapFlags::none;
5236
5237 mapData.OriginalValue.push_back(origValue);
5238 mapData.BasePointers.push_back(origValue);
5239 mapData.Pointers.push_back(origValue);
5240 mapData.IsDeclareTarget.push_back(false);
5241 mapData.BaseType.push_back(
5242 moduleTranslation.convertType(mapOp.getVarType()));
5243 mapData.Sizes.push_back(
5244 builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
5245 mapData.MapClause.push_back(mapOp.getOperation());
5246 if (llvm::to_underlying(mapType & mapTypeAlways)) {
5247 // Descriptors are mapped with the ALWAYS flag, since they can get
5248 // rematerialized, so the address of the decriptor for a given object
5249 // may change from one place to another.
5250 mapData.Types.push_back(mapType);
5251 // Technically it's possible for a non-descriptor mapping to have
5252 // both has-device-addr and ALWAYS, so lookup the mapper in case it
5253 // exists.
5254 if (mapOp.getMapperId()) {
5255 mapData.Mappers.push_back(
5257 mapOp, mapOp.getMapperIdAttr()));
5258 } else {
5259 mapData.Mappers.push_back(nullptr);
5260 }
5261 } else {
5262 // For is_device_ptr we need the map type to propagate so the runtime
5263 // can materialize the device-side copy of the pointer container.
5264 mapData.Types.push_back(
5265 isDevicePtr ? mapType
5266 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
5267 mapData.Mappers.push_back(nullptr);
5268 }
5269 mapData.Names.push_back(LLVM::createMappingInformation(
5270 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
5271 mapData.DevicePointers.push_back(
5272 isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
5273 : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
5274 mapData.IsAMapping.push_back(false);
5275 mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
5276 }
5277}
5278
5279static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
5280 auto *res = llvm::find(mapData.MapClause, memberOp);
5281 assert(res != mapData.MapClause.end() &&
5282 "MapInfoOp for member not found in MapData, cannot return index");
5283 return std::distance(mapData.MapClause.begin(), res);
5284}
5285
5287 omp::MapInfoOp mapInfo) {
5288 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5289 llvm::SmallVector<size_t> occludedChildren;
5290 llvm::sort(
5291 indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
5292 // Bail early if we are asked to look at the same index. If we do not
5293 // bail early, we can end up mistakenly adding indices to
5294 // occludedChildren. This can occur with some types of libc++ hardening.
5295 if (a == b)
5296 return false;
5297
5298 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
5299 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
5300
5301 for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
5302 int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
5303 int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
5304
5305 if (aIndex == bIndex)
5306 continue;
5307
5308 if (aIndex < bIndex)
5309 return true;
5310
5311 if (aIndex > bIndex)
5312 return false;
5313 }
5314
5315 // Iterated up until the end of the smallest member and
5316 // they were found to be equal up to that point, so select
5317 // the member with the lowest index count, so the "parent"
5318 bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
5319 if (memberAParent)
5320 occludedChildren.push_back(b);
5321 else
5322 occludedChildren.push_back(a);
5323 return memberAParent;
5324 });
5325
5326 // We remove children from the index list that are overshadowed by
5327 // a parent, this prevents us retrieving these as the first or last
5328 // element when the parent is the correct element in these cases.
5329 for (auto v : occludedChildren)
5330 indices.erase(std::remove(indices.begin(), indices.end(), v),
5331 indices.end());
5332}
5333
5334static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
5335 bool first) {
5336 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
5337 // Only 1 member has been mapped, we can return it.
5338 if (indexAttr.size() == 1)
5339 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
5340 llvm::SmallVector<size_t> indices(indexAttr.size());
5341 std::iota(indices.begin(), indices.end(), 0);
5342 sortMapIndices(indices, mapInfo);
5343 return llvm::cast<omp::MapInfoOp>(
5344 mapInfo.getMembers()[first ? indices.front() : indices.back()]
5345 .getDefiningOp());
5346}
5347
5348/// This function calculates the array/pointer offset for map data provided
5349/// with bounds operations, e.g. when provided something like the following:
5350///
5351/// Fortran
5352/// map(tofrom: array(2:5, 3:2))
5353///
5354/// We must calculate the initial pointer offset to pass across, this function
5355/// performs this using bounds.
5356///
5357/// TODO/WARNING: This only supports Fortran's column major indexing currently
5358/// as is noted in the note below and comments in the function, we must extend
5359/// this function when we add a C++ frontend.
5360/// NOTE: which while specified in row-major order it currently needs to be
5361/// flipped for Fortran's column order array allocation and access (as
5362/// opposed to C++'s row-major, hence the backwards processing where order is
5363/// important). This is likely important to keep in mind for the future when
5364/// we incorporate a C++ frontend, both frontends will need to agree on the
5365/// ordering of generated bounds operations (one may have to flip them) to
5366/// make the below lowering frontend agnostic. The offload size
5367/// calcualtion may also have to be adjusted for C++.
5368static std::vector<llvm::Value *>
5370 llvm::IRBuilderBase &builder, bool isArrayTy,
5371 OperandRange bounds) {
5372 std::vector<llvm::Value *> idx;
5373 // There's no bounds to calculate an offset from, we can safely
5374 // ignore and return no indices.
5375 if (bounds.empty())
5376 return idx;
5377
5378 // If we have an array type, then we have its type so can treat it as a
5379 // normal GEP instruction where the bounds operations are simply indexes
5380 // into the array. We currently do reverse order of the bounds, which
5381 // I believe leans more towards Fortran's column-major in memory.
5382 if (isArrayTy) {
5383 idx.push_back(builder.getInt64(0));
5384 for (int i = bounds.size() - 1; i >= 0; --i) {
5385 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5386 bounds[i].getDefiningOp())) {
5387 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
5388 }
5389 }
5390 } else {
5391 // If we do not have an array type, but we have bounds, then we're dealing
5392 // with a pointer that's being treated like an array and we have the
5393 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
5394 // address (pointer pointing to the actual data) so we must caclulate the
5395 // offset using a single index which the following loop attempts to
5396 // compute using the standard column-major algorithm e.g for a 3D array:
5397 //
5398 // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
5399 //
5400 // It is of note that it's doing column-major rather than row-major at the
5401 // moment, but having a way for the frontend to indicate which major format
5402 // to use or standardizing/canonicalizing the order of the bounds to compute
5403 // the offset may be useful in the future when there's other frontends with
5404 // different formats.
5405 std::vector<llvm::Value *> dimensionIndexSizeOffset;
5406 for (int i = bounds.size() - 1; i >= 0; --i) {
5407 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
5408 bounds[i].getDefiningOp())) {
5409 if (i == ((int)bounds.size() - 1))
5410 idx.emplace_back(
5411 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5412 else
5413 idx.back() = builder.CreateAdd(
5414 builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
5415 boundOp.getExtent())),
5416 moduleTranslation.lookupValue(boundOp.getLowerBound()));
5417 }
5418 }
5419 }
5420
5421 return idx;
5422}
5423
5425 llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
5426 return cast<IntegerAttr>(value).getInt();
5427 });
5428}
5429
5430// Gathers members that are overlapping in the parent, excluding members that
5431// themselves overlap, keeping the top-most (closest to parents level) map.
5432static void
5434 omp::MapInfoOp parentOp) {
5435 // No members mapped, no overlaps.
5436 if (parentOp.getMembers().empty())
5437 return;
5438
5439 // Single member, we can insert and return early.
5440 if (parentOp.getMembers().size() == 1) {
5441 overlapMapDataIdxs.push_back(0);
5442 return;
5443 }
5444
5445 // 1) collect list of top-level overlapping members from MemberOp
5447 ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
5448 for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
5449 memberByIndex.push_back(
5450 std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
5451
5452 // Sort the smallest first (higher up the parent -> member chain), so that
5453 // when we remove members, we remove as much as we can in the initial
5454 // iterations, shortening the number of passes required.
5455 llvm::sort(memberByIndex.begin(), memberByIndex.end(),
5456 [&](auto a, auto b) { return a.second.size() < b.second.size(); });
5457
5458 // Remove elements from the vector if there is a parent element that
5459 // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
5460 // [0,2].. etc.
5462 for (auto v : memberByIndex) {
5463 llvm::SmallVector<int64_t> vArr(v.second.size());
5464 getAsIntegers(v.second, vArr);
5465 skipList.push_back(
5466 *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
5467 if (v == x)
5468 return false;
5469 llvm::SmallVector<int64_t> xArr(x.second.size());
5470 getAsIntegers(x.second, xArr);
5471 return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
5472 xArr.size() >= vArr.size();
5473 }));
5474 }
5475
5476 // Collect the indices, as we need the base pointer etc. from the MapData
5477 // structure which is primarily accessible via index at the moment.
5478 for (auto v : memberByIndex)
5479 if (find(skipList.begin(), skipList.end(), v) == skipList.end())
5480 overlapMapDataIdxs.push_back(v.first);
5481}
5482
5483// The intent is to verify if the mapped data being passed is a
5484// pointer -> pointee that requires special handling in certain cases,
5485// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
5486//
5487// There may be a better way to verify this, but unfortunately with
5488// opaque pointers we lose the ability to easily check if something is
5489// a pointer whilst maintaining access to the underlying type.
5490static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
5491 // If we have a varPtrPtr field assigned then the underlying type is a pointer
5492 if (mapOp.getVarPtrPtr())
5493 return true;
5494
5495 // If the map data is declare target with a link clause, then it's represented
5496 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
5497 // no relation to pointers.
5498 if (isDeclareTargetLink(mapOp.getVarPtr()))
5499 return true;
5500
5501 return false;
5502}
5503
5504// This creates two insertions into the MapInfosTy data structure for the
5505// "parent" of a set of members, (usually a container e.g.
5506// class/structure/derived type) when subsequent members have also been
5507// explicitly mapped on the same map clause. Certain types, such as Fortran
5508// descriptors are mapped like this as well, however, the members are
5509// implicit as far as a user is concerned, but we must explicitly map them
5510// internally.
5511//
5512// This function also returns the memberOfFlag for this particular parent,
5513// which is utilised in subsequent member mappings (by modifying there map type
5514// with it) to indicate that a member is part of this parent and should be
5515// treated by the runtime as such. Important to achieve the correct mapping.
5516//
5517// This function borrows a lot from Clang's emitCombinedEntry function
5518// inside of CGOpenMPRuntime.cpp
5519static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
5520 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5521 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5522 MapInfoData &mapData, uint64_t mapDataIndex,
5523 TargetDirectiveEnumTy targetDirective) {
5524 assert(!ompBuilder.Config.isTargetDevice() &&
5525 "function only supported for host device codegen");
5526
5527 auto parentClause =
5528 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5529
5530 auto *parentMapper = mapData.Mappers[mapDataIndex];
5531
5532 // Map the first segment of the parent. If a user-defined mapper is attached,
5533 // include the parent's to/from-style bits (and common modifiers) in this
5534 // base entry so the mapper receives correct copy semantics via its 'type'
5535 // parameter. Also keep TARGET_PARAM when required for kernel arguments.
5536 llvm::omp::OpenMPOffloadMappingFlags baseFlag =
5537 (targetDirective == TargetDirectiveEnumTy::Target &&
5538 !mapData.IsDeclareTarget[mapDataIndex])
5539 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
5540 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
5541
5542 if (parentMapper) {
5543 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
5544 // Preserve relevant map-type bits from the parent clause. These include
5545 // the copy direction (TO/FROM), as well as commonly used modifiers that
5546 // should be visible to the mapper for correct behaviour.
5547 mapFlags parentFlags = mapData.Types[mapDataIndex];
5548 mapFlags preserve = mapFlags::OMP_MAP_TO | mapFlags::OMP_MAP_FROM |
5549 mapFlags::OMP_MAP_ALWAYS | mapFlags::OMP_MAP_CLOSE |
5550 mapFlags::OMP_MAP_PRESENT | mapFlags::OMP_MAP_OMPX_HOLD;
5551 baseFlag |= (parentFlags & preserve);
5552 }
5553
5554 combinedInfo.Types.emplace_back(baseFlag);
5555 combinedInfo.DevicePointers.emplace_back(
5556 mapData.DevicePointers[mapDataIndex]);
5557 // Only attach the mapper to the base entry when we are mapping the whole
5558 // parent. Combined/segment entries must not carry a mapper; otherwise the
5559 // mapper can be invoked with a partial size, which is undefined behaviour.
5560 combinedInfo.Mappers.emplace_back(
5561 parentMapper && !parentClause.getPartialMap() ? parentMapper : nullptr);
5562 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5563 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5564 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
5565
5566 // Calculate size of the parent object being mapped based on the
5567 // addresses at runtime, highAddr - lowAddr = size. This of course
5568 // doesn't factor in allocated data like pointers, hence the further
5569 // processing of members specified by users, or in the case of
5570 // Fortran pointers and allocatables, the mapping of the pointed to
5571 // data by the descriptor (which itself, is a structure containing
5572 // runtime information on the dynamically allocated data).
5573 llvm::Value *lowAddr, *highAddr;
5574 if (!parentClause.getPartialMap()) {
5575 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5576 builder.getPtrTy());
5577 highAddr = builder.CreatePointerCast(
5578 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5579 mapData.Pointers[mapDataIndex], 1),
5580 builder.getPtrTy());
5581 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5582 } else {
5583 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5584 int firstMemberIdx = getMapDataMemberIdx(
5585 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
5586 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
5587 builder.getPtrTy());
5588 int lastMemberIdx = getMapDataMemberIdx(
5589 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
5590 highAddr = builder.CreatePointerCast(
5591 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
5592 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
5593 builder.getPtrTy());
5594 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
5595 }
5596
5597 llvm::Value *size = builder.CreateIntCast(
5598 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5599 builder.getInt64Ty(),
5600 /*isSigned=*/false);
5601 combinedInfo.Sizes.push_back(size);
5602
5603 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
5604 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
5605
5606 // This creates the initial MEMBER_OF mapping that consists of
5607 // the parent/top level container (same as above effectively, except
5608 // with a fixed initial compile time size and separate maptype which
5609 // indicates the true mape type (tofrom etc.). This parent mapping is
5610 // only relevant if the structure in its totality is being mapped,
5611 // otherwise the above suffices.
5612 if (!parentClause.getPartialMap()) {
5613 // TODO: This will need to be expanded to include the whole host of logic
5614 // for the map flags that Clang currently supports (e.g. it should do some
5615 // further case specific flag modifications). For the moment, it handles
5616 // what we support as expected.
5617 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
5618 bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
5619 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
5620 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
5621 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5622
5623 if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
5624 combinedInfo.Types.emplace_back(mapFlag);
5625 combinedInfo.DevicePointers.emplace_back(
5626 mapData.DevicePointers[mapDataIndex]);
5627 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5628 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5629 combinedInfo.BasePointers.emplace_back(
5630 mapData.BasePointers[mapDataIndex]);
5631 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
5632 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
5633 combinedInfo.Mappers.emplace_back(nullptr);
5634 } else {
5635 llvm::SmallVector<size_t> overlapIdxs;
5636 // Find all of the members that "overlap", i.e. occlude other members that
5637 // were mapped alongside the parent, e.g. member [0], occludes [0,1] and
5638 // [0,2], but not [1,0].
5639 getOverlappedMembers(overlapIdxs, parentClause);
5640 // We need to make sure the overlapped members are sorted in order of
5641 // lowest address to highest address.
5642 sortMapIndices(overlapIdxs, parentClause);
5643
5644 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
5645 builder.getPtrTy());
5646 highAddr = builder.CreatePointerCast(
5647 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
5648 mapData.Pointers[mapDataIndex], 1),
5649 builder.getPtrTy());
5650
5651 // TODO: We may want to skip arrays/array sections in this as Clang does.
5652 // It appears to be an optimisation rather than a necessity though,
5653 // but this requires further investigation. However, we would have to make
5654 // sure to not exclude maps with bounds that ARE pointers, as these are
5655 // processed as separate components, i.e. pointer + data.
5656 for (auto v : overlapIdxs) {
5657 auto mapDataOverlapIdx = getMapDataMemberIdx(
5658 mapData,
5659 cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
5660 combinedInfo.Types.emplace_back(mapFlag);
5661 combinedInfo.DevicePointers.emplace_back(
5662 mapData.DevicePointers[mapDataOverlapIdx]);
5663 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5664 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5665 combinedInfo.BasePointers.emplace_back(
5666 mapData.BasePointers[mapDataIndex]);
5667 combinedInfo.Mappers.emplace_back(nullptr);
5668 combinedInfo.Pointers.emplace_back(lowAddr);
5669 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5670 builder.CreatePtrDiff(builder.getInt8Ty(),
5671 mapData.OriginalValue[mapDataOverlapIdx],
5672 lowAddr),
5673 builder.getInt64Ty(), /*isSigned=*/true));
5674 lowAddr = builder.CreateConstGEP1_32(
5675 checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
5676 mapData.MapClause[mapDataOverlapIdx]))
5677 ? builder.getPtrTy()
5678 : mapData.BaseType[mapDataOverlapIdx],
5679 mapData.BasePointers[mapDataOverlapIdx], 1);
5680 }
5681
5682 combinedInfo.Types.emplace_back(mapFlag);
5683 combinedInfo.DevicePointers.emplace_back(
5684 mapData.DevicePointers[mapDataIndex]);
5685 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
5686 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
5687 combinedInfo.BasePointers.emplace_back(
5688 mapData.BasePointers[mapDataIndex]);
5689 combinedInfo.Mappers.emplace_back(nullptr);
5690 combinedInfo.Pointers.emplace_back(lowAddr);
5691 combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
5692 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
5693 builder.getInt64Ty(), true));
5694 }
5695 }
5696 return memberOfFlag;
5697}
5698
5699// This function is intended to add explicit mappings of members
5701 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
5702 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
5703 MapInfoData &mapData, uint64_t mapDataIndex,
5704 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
5705 TargetDirectiveEnumTy targetDirective) {
5706 assert(!ompBuilder.Config.isTargetDevice() &&
5707 "function only supported for host device codegen");
5708
5709 auto parentClause =
5710 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5711
5712 for (auto mappedMembers : parentClause.getMembers()) {
5713 auto memberClause =
5714 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
5715 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5716
5717 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
5718
5719 // If we're currently mapping a pointer to a block of data, we must
5720 // initially map the pointer, and then attatch/bind the data with a
5721 // subsequent map to the pointer. This segment of code generates the
5722 // pointer mapping, which can in certain cases be optimised out as Clang
5723 // currently does in its lowering. However, for the moment we do not do so,
5724 // in part as we currently have substantially less information on the data
5725 // being mapped at this stage.
5726 if (checkIfPointerMap(memberClause)) {
5727 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5728 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5729 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5730 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5731 combinedInfo.Types.emplace_back(mapFlag);
5732 combinedInfo.DevicePointers.emplace_back(
5733 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
5734 combinedInfo.Mappers.emplace_back(nullptr);
5735 combinedInfo.Names.emplace_back(
5736 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5737 combinedInfo.BasePointers.emplace_back(
5738 mapData.BasePointers[mapDataIndex]);
5739 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
5740 combinedInfo.Sizes.emplace_back(builder.getInt64(
5741 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
5742 }
5743
5744 // Same MemberOfFlag to indicate its link with parent and other members
5745 // of.
5746 auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
5747 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5748 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
5749 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
5750 bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
5751 ? parentClause.getVarPtr()
5752 : parentClause.getVarPtrPtr());
5753 if (checkIfPointerMap(memberClause) &&
5754 (!isDeclTargetTo ||
5755 (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
5756 targetDirective != TargetDirectiveEnumTy::TargetData))) {
5757 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5758 }
5759
5760 combinedInfo.Types.emplace_back(mapFlag);
5761 combinedInfo.DevicePointers.emplace_back(
5762 mapData.DevicePointers[memberDataIdx]);
5763 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
5764 combinedInfo.Names.emplace_back(
5765 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
5766 uint64_t basePointerIndex =
5767 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
5768 combinedInfo.BasePointers.emplace_back(
5769 mapData.BasePointers[basePointerIndex]);
5770 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
5771
5772 llvm::Value *size = mapData.Sizes[memberDataIdx];
5773 if (checkIfPointerMap(memberClause)) {
5774 size = builder.CreateSelect(
5775 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
5776 builder.getInt64(0), size);
5777 }
5778
5779 combinedInfo.Sizes.emplace_back(size);
5780 }
5781}
5782
5783static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
5784 MapInfosTy &combinedInfo,
5785 TargetDirectiveEnumTy targetDirective,
5786 int mapDataParentIdx = -1) {
5787 // Declare Target Mappings are excluded from being marked as
5788 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
5789 // marked with OMP_MAP_PTR_AND_OBJ instead.
5790 auto mapFlag = mapData.Types[mapDataIdx];
5791 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
5792
5793 bool isPtrTy = checkIfPointerMap(mapInfoOp);
5794 if (isPtrTy)
5795 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
5796
5797 if (targetDirective == TargetDirectiveEnumTy::Target &&
5798 !mapData.IsDeclareTarget[mapDataIdx])
5799 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
5800
5801 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
5802 !isPtrTy)
5803 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
5804
5805 // if we're provided a mapDataParentIdx, then the data being mapped is
5806 // part of a larger object (in a parent <-> member mapping) and in this
5807 // case our BasePointer should be the parent.
5808 if (mapDataParentIdx >= 0)
5809 combinedInfo.BasePointers.emplace_back(
5810 mapData.BasePointers[mapDataParentIdx]);
5811 else
5812 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
5813
5814 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
5815 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
5816 combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
5817 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
5818 combinedInfo.Types.emplace_back(mapFlag);
5819 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
5820}
5821
5823 llvm::IRBuilderBase &builder,
5824 llvm::OpenMPIRBuilder &ompBuilder,
5825 DataLayout &dl, MapInfosTy &combinedInfo,
5826 MapInfoData &mapData, uint64_t mapDataIndex,
5827 TargetDirectiveEnumTy targetDirective) {
5828 assert(!ompBuilder.Config.isTargetDevice() &&
5829 "function only supported for host device codegen");
5830
5831 auto parentClause =
5832 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
5833
5834 // If we have a partial map (no parent referenced in the map clauses of the
5835 // directive, only members) and only a single member, we do not need to bind
5836 // the map of the member to the parent, we can pass the member separately.
5837 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
5838 auto memberClause = llvm::cast<omp::MapInfoOp>(
5839 parentClause.getMembers()[0].getDefiningOp());
5840 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
5841 // Note: Clang treats arrays with explicit bounds that fall into this
5842 // category as a parent with map case, however, it seems this isn't a
5843 // requirement, and processing them as an individual map is fine. So,
5844 // we will handle them as individual maps for the moment, as it's
5845 // difficult for us to check this as we always require bounds to be
5846 // specified currently and it's also marginally more optimal (single
5847 // map rather than two). The difference may come from the fact that
5848 // Clang maps array without bounds as pointers (which we do not
5849 // currently do), whereas we treat them as arrays in all cases
5850 // currently.
5851 processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
5852 mapDataIndex);
5853 return;
5854 }
5855
5856 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
5857 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
5858 combinedInfo, mapData, mapDataIndex,
5859 targetDirective);
5860 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
5861 combinedInfo, mapData, mapDataIndex,
5862 memberOfParentFlag, targetDirective);
5863}
5864
5865// This is a variation on Clang's GenerateOpenMPCapturedVars, which
5866// generates different operation (e.g. load/store) combinations for
5867// arguments to the kernel, based on map capture kinds which are then
5868// utilised in the combinedInfo in place of the original Map value.
5869static void
5870createAlteredByCaptureMap(MapInfoData &mapData,
5871 LLVM::ModuleTranslation &moduleTranslation,
5872 llvm::IRBuilderBase &builder) {
5873 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5874 "function only supported for host device codegen");
5875 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5876 // if it's declare target, skip it, it's handled separately.
5877 if (!mapData.IsDeclareTarget[i]) {
5878 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
5879 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
5880 bool isPtrTy = checkIfPointerMap(mapOp);
5881
5882 // Currently handles array sectioning lowerbound case, but more
5883 // logic may be required in the future. Clang invokes EmitLValue,
5884 // which has specialised logic for special Clang types such as user
5885 // defines, so it is possible we will have to extend this for
5886 // structures or other complex types. As the general idea is that this
5887 // function mimics some of the logic from Clang that we require for
5888 // kernel argument passing from host -> device.
5889 switch (captureKind) {
5890 case omp::VariableCaptureKind::ByRef: {
5891 llvm::Value *newV = mapData.Pointers[i];
5892 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
5893 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
5894 mapOp.getBounds());
5895 if (isPtrTy)
5896 newV = builder.CreateLoad(builder.getPtrTy(), newV);
5897
5898 if (!offsetIdx.empty())
5899 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
5900 "array_offset");
5901 mapData.Pointers[i] = newV;
5902 } break;
5903 case omp::VariableCaptureKind::ByCopy: {
5904 llvm::Type *type = mapData.BaseType[i];
5905 llvm::Value *newV;
5906 if (mapData.Pointers[i]->getType()->isPointerTy())
5907 newV = builder.CreateLoad(type, mapData.Pointers[i]);
5908 else
5909 newV = mapData.Pointers[i];
5910
5911 if (!isPtrTy) {
5912 auto curInsert = builder.saveIP();
5913 llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
5914 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
5915 auto *memTempAlloc =
5916 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
5917 builder.SetCurrentDebugLocation(DbgLoc);
5918 builder.restoreIP(curInsert);
5919
5920 builder.CreateStore(newV, memTempAlloc);
5921 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
5922 }
5923
5924 mapData.Pointers[i] = newV;
5925 mapData.BasePointers[i] = newV;
5926 } break;
5927 case omp::VariableCaptureKind::This:
5928 case omp::VariableCaptureKind::VLAType:
5929 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
5930 break;
5931 }
5932 }
5933 }
5934}
5935
5936// Generate all map related information and fill the combinedInfo.
5937static void genMapInfos(llvm::IRBuilderBase &builder,
5938 LLVM::ModuleTranslation &moduleTranslation,
5939 DataLayout &dl, MapInfosTy &combinedInfo,
5940 MapInfoData &mapData,
5941 TargetDirectiveEnumTy targetDirective) {
5942 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5943 "function only supported for host device codegen");
5944 // We wish to modify some of the methods in which arguments are
5945 // passed based on their capture type by the target region, this can
5946 // involve generating new loads and stores, which changes the
5947 // MLIR value to LLVM value mapping, however, we only wish to do this
5948 // locally for the current function/target and also avoid altering
5949 // ModuleTranslation, so we remap the base pointer or pointer stored
5950 // in the map infos corresponding MapInfoData, which is later accessed
5951 // by genMapInfos and createTarget to help generate the kernel and
5952 // kernel arg structure. It primarily becomes relevant in cases like
5953 // bycopy, or byref range'd arrays. In the default case, we simply
5954 // pass thee pointer byref as both basePointer and pointer.
5955 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
5956
5957 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5958
5959 // We operate under the assumption that all vectors that are
5960 // required in MapInfoData are of equal lengths (either filled with
5961 // default constructed data or appropiate information) so we can
5962 // utilise the size from any component of MapInfoData, if we can't
5963 // something is missing from the initial MapInfoData construction.
5964 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
5965 // NOTE/TODO: We currently do not support arbitrary depth record
5966 // type mapping.
5967 if (mapData.IsAMember[i])
5968 continue;
5969
5970 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
5971 if (!mapInfoOp.getMembers().empty()) {
5972 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
5973 combinedInfo, mapData, i, targetDirective);
5974 continue;
5975 }
5976
5977 processIndividualMap(mapData, i, combinedInfo, targetDirective);
5978 }
5979}
5980
5981static llvm::Expected<llvm::Function *>
5982emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
5983 LLVM::ModuleTranslation &moduleTranslation,
5984 llvm::StringRef mapperFuncName,
5985 TargetDirectiveEnumTy targetDirective);
5986
5987static llvm::Expected<llvm::Function *>
5988getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
5989 LLVM::ModuleTranslation &moduleTranslation,
5990 TargetDirectiveEnumTy targetDirective) {
5991 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
5992 "function only supported for host device codegen");
5993 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
5994 std::string mapperFuncName =
5995 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
5996 {"omp_mapper", declMapperOp.getSymName()});
5997
5998 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
5999 return lookupFunc;
6000
6001 // Recursive types can cause re-entrant mapper emission. The mapper function
6002 // is created by OpenMPIRBuilder before the callbacks run, so it may already
6003 // exist in the LLVM module even though it is not yet registered in the
6004 // ModuleTranslation mapping table. Reuse and register it to break the
6005 // recursion.
6006 if (llvm::Function *existingFunc =
6007 moduleTranslation.getLLVMModule()->getFunction(mapperFuncName)) {
6008 moduleTranslation.mapFunction(mapperFuncName, existingFunc);
6009 return existingFunc;
6010 }
6011
6012 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
6013 mapperFuncName, targetDirective);
6014}
6015
6016static llvm::Expected<llvm::Function *>
6017emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
6018 LLVM::ModuleTranslation &moduleTranslation,
6019 llvm::StringRef mapperFuncName,
6020 TargetDirectiveEnumTy targetDirective) {
6021 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6022 "function only supported for host device codegen");
6023 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
6024 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
6025 DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
6026 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6027 llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType());
6028 SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
6029
6030 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6031
6032 // Fill up the arrays with all the mapped variables.
6033 MapInfosTy combinedInfo;
6034 auto genMapInfoCB =
6035 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
6036 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
6037 builder.restoreIP(codeGenIP);
6038 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
6039 moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
6040 builder.GetInsertBlock());
6041 if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
6042 /*ignoreArguments=*/true,
6043 builder)))
6044 return llvm::make_error<PreviouslyReportedError>();
6045 MapInfoData mapData;
6046 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
6047 builder);
6048 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
6049 targetDirective);
6050
6051 // Drop the mapping that is no longer necessary so that the same region
6052 // can be processed multiple times.
6053 moduleTranslation.forgetMapping(declMapperOp.getRegion());
6054 return combinedInfo;
6055 };
6056
6057 auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
6058 if (!combinedInfo.Mappers[i])
6059 return nullptr;
6060 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6061 moduleTranslation, targetDirective);
6062 };
6063
6064 llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
6065 genMapInfoCB, varType, mapperFuncName, customMapperCB);
6066 if (!newFn)
6067 return newFn.takeError();
6068 if ([[maybe_unused]] llvm::Function *mappedFunc =
6069 moduleTranslation.lookupFunction(mapperFuncName)) {
6070 assert(mappedFunc == *newFn &&
6071 "mapper function mapping disagrees with emitted function");
6072 } else {
6073 moduleTranslation.mapFunction(mapperFuncName, *newFn);
6074 }
6075 return *newFn;
6076}
6077
6078static LogicalResult
6079convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
6080 LLVM::ModuleTranslation &moduleTranslation) {
6081 llvm::Value *ifCond = nullptr;
6082 llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
6083 SmallVector<Value> mapVars;
6084 SmallVector<Value> useDevicePtrVars;
6085 SmallVector<Value> useDeviceAddrVars;
6086 llvm::omp::RuntimeFunction RTLFn;
6087 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
6088 TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
6089
6090 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6091 llvm::OpenMPIRBuilder::TargetDataInfo info(
6092 /*RequiresDevicePointerInfo=*/true,
6093 /*SeparateBeginEndCalls=*/true);
6094 assert(!ompBuilder->Config.isTargetDevice() &&
6095 "target data/enter/exit/update are host ops");
6096 bool isOffloadEntry = !ompBuilder->Config.TargetTriples.empty();
6097
6098 auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
6099 llvm::Value *v = moduleTranslation.lookupValue(dev);
6100 return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
6101 };
6102
6103 LogicalResult result =
6105 .Case([&](omp::TargetDataOp dataOp) {
6106 if (failed(checkImplementationStatus(*dataOp)))
6107 return failure();
6108
6109 if (auto ifVar = dataOp.getIfExpr())
6110 ifCond = moduleTranslation.lookupValue(ifVar);
6111
6112 if (mlir::Value devId = dataOp.getDevice())
6113 deviceID = getDeviceID(devId);
6114
6115 mapVars = dataOp.getMapVars();
6116 useDevicePtrVars = dataOp.getUseDevicePtrVars();
6117 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
6118 return success();
6119 })
6120 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
6121 if (failed(checkImplementationStatus(*enterDataOp)))
6122 return failure();
6123
6124 if (auto ifVar = enterDataOp.getIfExpr())
6125 ifCond = moduleTranslation.lookupValue(ifVar);
6126
6127 if (mlir::Value devId = enterDataOp.getDevice())
6128 deviceID = getDeviceID(devId);
6129
6130 RTLFn =
6131 enterDataOp.getNowait()
6132 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
6133 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
6134 mapVars = enterDataOp.getMapVars();
6135 info.HasNoWait = enterDataOp.getNowait();
6136 return success();
6137 })
6138 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
6139 if (failed(checkImplementationStatus(*exitDataOp)))
6140 return failure();
6141
6142 if (auto ifVar = exitDataOp.getIfExpr())
6143 ifCond = moduleTranslation.lookupValue(ifVar);
6144
6145 if (mlir::Value devId = exitDataOp.getDevice())
6146 deviceID = getDeviceID(devId);
6147
6148 RTLFn = exitDataOp.getNowait()
6149 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
6150 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
6151 mapVars = exitDataOp.getMapVars();
6152 info.HasNoWait = exitDataOp.getNowait();
6153 return success();
6154 })
6155 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
6156 if (failed(checkImplementationStatus(*updateDataOp)))
6157 return failure();
6158
6159 if (auto ifVar = updateDataOp.getIfExpr())
6160 ifCond = moduleTranslation.lookupValue(ifVar);
6161
6162 if (mlir::Value devId = updateDataOp.getDevice())
6163 deviceID = getDeviceID(devId);
6164
6165 RTLFn =
6166 updateDataOp.getNowait()
6167 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
6168 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
6169 mapVars = updateDataOp.getMapVars();
6170 info.HasNoWait = updateDataOp.getNowait();
6171 return success();
6172 })
6173 .DefaultUnreachable("unexpected operation");
6174
6175 if (failed(result))
6176 return failure();
6177 // Pretend we have IF(false) if we're not doing offload.
6178 if (!isOffloadEntry)
6179 ifCond = builder.getFalse();
6180
6181 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6182 MapInfoData mapData;
6183 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
6184 builder, useDevicePtrVars, useDeviceAddrVars);
6185
6186 // Fill up the arrays with all the mapped variables.
6187 MapInfosTy combinedInfo;
6188 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
6189 builder.restoreIP(codeGenIP);
6190 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
6191 targetDirective);
6192 return combinedInfo;
6193 };
6194
6195 // Define a lambda to apply mappings between use_device_addr and
6196 // use_device_ptr base pointers, and their associated block arguments.
6197 auto mapUseDevice =
6198 [&moduleTranslation](
6199 llvm::OpenMPIRBuilder::DeviceInfoTy type,
6201 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
6202 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
6203 for (auto [arg, useDevVar] :
6204 llvm::zip_equal(blockArgs, useDeviceVars)) {
6205
6206 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
6207 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
6208 : mapInfoOp.getVarPtr();
6209 };
6210
6211 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
6212 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
6213 mapInfoData.MapClause, mapInfoData.DevicePointers,
6214 mapInfoData.BasePointers)) {
6215 auto mapOp = cast<omp::MapInfoOp>(mapClause);
6216 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
6217 devicePointer != type)
6218 continue;
6219
6220 if (llvm::Value *devPtrInfoMap =
6221 mapper ? mapper(basePointer) : basePointer) {
6222 moduleTranslation.mapValue(arg, devPtrInfoMap);
6223 break;
6224 }
6225 }
6226 }
6227 };
6228
6229 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
6230 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
6231 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
6232 // We must always restoreIP regardless of doing anything the caller
6233 // does not restore it, leading to incorrect (no) branch generation.
6234 builder.restoreIP(codeGenIP);
6235 assert(isa<omp::TargetDataOp>(op) &&
6236 "BodyGen requested for non TargetDataOp");
6237 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
6238 Region &region = cast<omp::TargetDataOp>(op).getRegion();
6239 switch (bodyGenType) {
6240 case BodyGenTy::Priv:
6241 // Check if any device ptr/addr info is available
6242 if (!info.DevicePtrInfoMap.empty()) {
6243 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6244 blockArgIface.getUseDeviceAddrBlockArgs(),
6245 useDeviceAddrVars, mapData,
6246 [&](llvm::Value *basePointer) -> llvm::Value * {
6247 if (!info.DevicePtrInfoMap[basePointer].second)
6248 return nullptr;
6249 return builder.CreateLoad(
6250 builder.getPtrTy(),
6251 info.DevicePtrInfoMap[basePointer].second);
6252 });
6253 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6254 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6255 mapData, [&](llvm::Value *basePointer) {
6256 return info.DevicePtrInfoMap[basePointer].second;
6257 });
6258
6259 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6260 moduleTranslation)))
6261 return llvm::make_error<PreviouslyReportedError>();
6262 }
6263 break;
6264 case BodyGenTy::DupNoPriv:
6265 if (info.DevicePtrInfoMap.empty()) {
6266 // For host device we still need to do the mapping for codegen,
6267 // otherwise it may try to lookup a missing value.
6268 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
6269 blockArgIface.getUseDeviceAddrBlockArgs(),
6270 useDeviceAddrVars, mapData);
6271 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
6272 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
6273 mapData);
6274 }
6275 break;
6276 case BodyGenTy::NoPriv:
6277 // If device info is available then region has already been generated
6278 if (info.DevicePtrInfoMap.empty()) {
6279 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
6280 moduleTranslation)))
6281 return llvm::make_error<PreviouslyReportedError>();
6282 }
6283 break;
6284 }
6285 return builder.saveIP();
6286 };
6287
6288 auto customMapperCB =
6289 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
6290 if (!combinedInfo.Mappers[i])
6291 return nullptr;
6292 info.HasMapper = true;
6293 return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
6294 moduleTranslation, targetDirective);
6295 };
6296
6297 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6298 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6299 findAllocaInsertPoint(builder, moduleTranslation);
6300 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
6301 if (isa<omp::TargetDataOp>(op))
6302 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6303 deviceID, ifCond, info, genMapInfoCB,
6304 customMapperCB,
6305 /*MapperFunc=*/nullptr, bodyGenCB,
6306 /*DeviceAddrCB=*/nullptr);
6307 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
6308 deviceID, ifCond, info, genMapInfoCB,
6309 customMapperCB, &RTLFn);
6310 }();
6311
6312 if (failed(handleError(afterIP, *op)))
6313 return failure();
6314
6315 builder.restoreIP(*afterIP);
6316 return success();
6317}
6318
6319static LogicalResult
6320convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
6321 LLVM::ModuleTranslation &moduleTranslation) {
6322 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6323 auto distributeOp = cast<omp::DistributeOp>(opInst);
6324 if (failed(checkImplementationStatus(opInst)))
6325 return failure();
6326
6327 /// Process teams op reduction in distribute if the reduction is contained in
6328 /// this specific distribute op.
6329 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
6330 bool doDistributeReduction =
6331 teamsOp && getDistributeCapturingTeamsReduction(teamsOp) == distributeOp;
6332
6333 DenseMap<Value, llvm::Value *> reductionVariableMap;
6334 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
6336 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
6337 llvm::ArrayRef<bool> isByRef;
6338
6339 if (doDistributeReduction) {
6340 isByRef = getIsByRef(teamsOp.getReductionByref());
6341 assert(isByRef.size() == teamsOp.getNumReductionVars());
6342
6343 collectReductionDecls(teamsOp, reductionDecls);
6344 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6345 findAllocaInsertPoint(builder, moduleTranslation);
6346
6347 MutableArrayRef<BlockArgument> reductionArgs =
6348 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
6349 .getReductionBlockArgs();
6350
6352 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
6353 reductionDecls, privateReductionVariables, reductionVariableMap,
6354 isByRef)))
6355 return failure();
6356 }
6357
6358 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
6359 auto bodyGenCB = [&](InsertPointTy allocaIP,
6360 InsertPointTy codeGenIP) -> llvm::Error {
6361 // Save the alloca insertion point on ModuleTranslation stack for use in
6362 // nested regions.
6364 moduleTranslation, allocaIP);
6365
6366 // DistributeOp has only one region associated with it.
6367 builder.restoreIP(codeGenIP);
6368 PrivateVarsInfo privVarsInfo(distributeOp);
6369
6371 allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
6372 if (handleError(afterAllocas, opInst).failed())
6373 return llvm::make_error<PreviouslyReportedError>();
6374
6375 if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
6376 opInst)
6377 .failed())
6378 return llvm::make_error<PreviouslyReportedError>();
6379
6380 if (failed(copyFirstPrivateVars(
6381 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
6382 privVarsInfo.llvmVars, privVarsInfo.privatizers,
6383 distributeOp.getPrivateNeedsBarrier())))
6384 return llvm::make_error<PreviouslyReportedError>();
6385
6386 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6387 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6389 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
6390 builder, moduleTranslation);
6391 if (!regionBlock)
6392 return regionBlock.takeError();
6393 builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
6394
6395 // Skip applying a workshare loop below when translating 'distribute
6396 // parallel do' (it's been already handled by this point while translating
6397 // the nested omp.wsloop).
6398 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
6399 // TODO: Add support for clauses which are valid for DISTRIBUTE
6400 // constructs. Static schedule is the default.
6401 bool hasDistSchedule = distributeOp.getDistScheduleStatic();
6402 auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
6403 : omp::ClauseScheduleKind::Static;
6404 // dist_schedule clauses are ordered - otherise this should be false
6405 bool isOrdered = hasDistSchedule;
6406 std::optional<omp::ScheduleModifier> scheduleMod;
6407 bool isSimd = false;
6408 llvm::omp::WorksharingLoopType workshareLoopType =
6409 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
6410 bool loopNeedsBarrier = false;
6411 llvm::Value *chunk = moduleTranslation.lookupValue(
6412 distributeOp.getDistScheduleChunkSize());
6413 llvm::CanonicalLoopInfo *loopInfo =
6414 findCurrentLoopInfo(moduleTranslation);
6415 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
6416 ompBuilder->applyWorkshareLoop(
6417 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
6418 convertToScheduleKind(schedule), chunk, isSimd,
6419 scheduleMod == omp::ScheduleModifier::monotonic,
6420 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
6421 workshareLoopType, false, hasDistSchedule, chunk);
6422
6423 if (!wsloopIP)
6424 return wsloopIP.takeError();
6425 }
6426 if (failed(cleanupPrivateVars(builder, moduleTranslation,
6427 distributeOp.getLoc(), privVarsInfo.llvmVars,
6428 privVarsInfo.privatizers)))
6429 return llvm::make_error<PreviouslyReportedError>();
6430
6431 return llvm::Error::success();
6432 };
6433
6434 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
6435 findAllocaInsertPoint(builder, moduleTranslation);
6436 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
6437 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
6438 ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
6439
6440 if (failed(handleError(afterIP, opInst)))
6441 return failure();
6442
6443 builder.restoreIP(*afterIP);
6444
6445 if (doDistributeReduction) {
6446 // Process the reductions if required.
6448 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
6449 privateReductionVariables, isByRef,
6450 /*isNoWait*/ false, /*isTeamsReduction*/ true);
6451 }
6452 return success();
6453}
6454
6455/// Lowers the FlagsAttr which is applied to the module on the device
6456/// pass when offloading, this attribute contains OpenMP RTL globals that can
6457/// be passed as flags to the frontend, otherwise they are set to default
6458static LogicalResult
6459convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
6460 LLVM::ModuleTranslation &moduleTranslation) {
6461 if (!cast<mlir::ModuleOp>(op))
6462 return failure();
6463
6464 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6465
6466 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
6467 attribute.getOpenmpDeviceVersion());
6468
6469 if (attribute.getNoGpuLib())
6470 return success();
6471
6472 ompBuilder->createGlobalFlag(
6473 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
6474 "__omp_rtl_debug_kind");
6475 ompBuilder->createGlobalFlag(
6476 attribute
6477 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
6478 ,
6479 "__omp_rtl_assume_teams_oversubscription");
6480 ompBuilder->createGlobalFlag(
6481 attribute
6482 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
6483 ,
6484 "__omp_rtl_assume_threads_oversubscription");
6485 ompBuilder->createGlobalFlag(
6486 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
6487 "__omp_rtl_assume_no_thread_state");
6488 ompBuilder->createGlobalFlag(
6489 attribute
6490 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
6491 ,
6492 "__omp_rtl_assume_no_nested_parallelism");
6493 return success();
6494}
6495
6496static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
6497 omp::TargetOp targetOp,
6498 llvm::StringRef parentName = "") {
6499 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
6500
6501 assert(fileLoc && "No file found from location");
6502 StringRef fileName = fileLoc.getFilename().getValue();
6503
6504 llvm::sys::fs::UniqueID id;
6505 uint64_t line = fileLoc.getLine();
6506 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
6507 size_t fileHash = llvm::hash_value(fileName.str());
6508 size_t deviceId = 0xdeadf17e;
6509 targetInfo =
6510 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
6511 } else {
6512 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
6513 id.getFile(), line);
6514 }
6515}
6516
6517static void
6518handleDeclareTargetMapVar(MapInfoData &mapData,
6519 LLVM::ModuleTranslation &moduleTranslation,
6520 llvm::IRBuilderBase &builder, llvm::Function *func) {
6521 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
6522 "function only supported for target device codegen");
6523 llvm::IRBuilderBase::InsertPointGuard guard(builder);
6524 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
6525 // In the case of declare target mapped variables, the basePointer is
6526 // the reference pointer generated by the convertDeclareTargetAttr
6527 // method. Whereas the kernelValue is the original variable, so for
6528 // the device we must replace all uses of this original global variable
6529 // (stored in kernelValue) with the reference pointer (stored in
6530 // basePointer for declare target mapped variables), as for device the
6531 // data is mapped into this reference pointer and should be loaded
6532 // from it, the original variable is discarded. On host both exist and
6533 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
6534 // function to link the two variables in the runtime and then both the
6535 // reference pointer and the pointer are assigned in the kernel argument
6536 // structure for the host.
6537 if (mapData.IsDeclareTarget[i]) {
6538 // If the original map value is a constant, then we have to make sure all
6539 // of it's uses within the current kernel/function that we are going to
6540 // rewrite are converted to instructions, as we will be altering the old
6541 // use (OriginalValue) from a constant to an instruction, which will be
6542 // illegal and ICE the compiler if the user is a constant expression of
6543 // some kind e.g. a constant GEP.
6544 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
6545 convertUsersOfConstantsToInstructions(constant, func, false);
6546
6547 // The users iterator will get invalidated if we modify an element,
6548 // so we populate this vector of uses to alter each user on an
6549 // individual basis to emit its own load (rather than one load for
6550 // all).
6552 for (llvm::User *user : mapData.OriginalValue[i]->users())
6553 userVec.push_back(user);
6554
6555 for (llvm::User *user : userVec) {
6556 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
6557 if (insn->getFunction() == func) {
6558 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6559 llvm::Value *substitute = mapData.BasePointers[i];
6560 if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
6561 : mapOp.getVarPtr())) {
6562 builder.SetCurrentDebugLocation(insn->getDebugLoc());
6563 substitute = builder.CreateLoad(
6564 mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
6565 cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
6566 }
6567 user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
6568 }
6569 }
6570 }
6571 }
6572 }
6573}
6574
6575// The createDeviceArgumentAccessor function generates
6576// instructions for retrieving (acessing) kernel
6577// arguments inside of the device kernel for use by
6578// the kernel. This enables different semantics such as
6579// the creation of temporary copies of data allowing
6580// semantics like read-only/no host write back kernel
6581// arguments.
6582//
6583// This currently implements a very light version of Clang's
6584// EmitParmDecl's handling of direct argument handling as well
6585// as a portion of the argument access generation based on
6586// capture types found at the end of emitOutlinedFunctionPrologue
6587// in Clang. The indirect path handling of EmitParmDecl's may be
6588// required for future work, but a direct 1-to-1 copy doesn't seem
6589// possible as the logic is rather scattered throughout Clang's
6590// lowering and perhaps we wish to deviate slightly.
6591//
6592// \param mapData - A container containing vectors of information
6593// corresponding to the input argument, which should have a
6594// corresponding entry in the MapInfoData containers
6595// OrigialValue's.
6596// \param arg - This is the generated kernel function argument that
6597// corresponds to the passed in input argument. We generated different
6598// accesses of this Argument, based on capture type and other Input
6599// related information.
6600// \param input - This is the host side value that will be passed to
6601// the kernel i.e. the kernel input, we rewrite all uses of this within
6602// the kernel (as we generate the kernel body based on the target's region
6603// which maintians references to the original input) to the retVal argument
6604// apon exit of this function inside of the OMPIRBuilder. This interlinks
6605// the kernel argument to future uses of it in the function providing
6606// appropriate "glue" instructions inbetween.
6607// \param retVal - This is the value that all uses of input inside of the
6608// kernel will be re-written to, the goal of this function is to generate
6609// an appropriate location for the kernel argument to be accessed from,
6610// e.g. ByRef will result in a temporary allocation location and then
6611// a store of the kernel argument into this allocated memory which
6612// will then be loaded from, ByCopy will use the allocated memory
6613// directly.
6614static llvm::IRBuilderBase::InsertPoint
6615createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
6616 llvm::Value *input, llvm::Value *&retVal,
6617 llvm::IRBuilderBase &builder,
6618 llvm::OpenMPIRBuilder &ompBuilder,
6619 LLVM::ModuleTranslation &moduleTranslation,
6620 llvm::IRBuilderBase::InsertPoint allocaIP,
6621 llvm::IRBuilderBase::InsertPoint codeGenIP) {
6622 assert(ompBuilder.Config.isTargetDevice() &&
6623 "function only supported for target device codegen");
6624 builder.restoreIP(allocaIP);
6625
6626 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
6627 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
6628 ompBuilder.M.getContext());
6629 unsigned alignmentValue = 0;
6630 // Find the associated MapInfoData entry for the current input
6631 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
6632 if (mapData.OriginalValue[i] == input) {
6633 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
6634 capture = mapOp.getMapCaptureType();
6635 // Get information of alignment of mapped object
6636 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
6637 mapOp.getVarType(), ompBuilder.M.getDataLayout());
6638 break;
6639 }
6640
6641 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
6642 unsigned int defaultAS =
6643 ompBuilder.M.getDataLayout().getProgramAddressSpace();
6644
6645 // Create the alloca for the argument the current point.
6646 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
6647
6648 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
6649 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
6650
6651 builder.CreateStore(&arg, v);
6652
6653 builder.restoreIP(codeGenIP);
6654
6655 switch (capture) {
6656 case omp::VariableCaptureKind::ByCopy: {
6657 retVal = v;
6658 break;
6659 }
6660 case omp::VariableCaptureKind::ByRef: {
6661 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
6662 v->getType(), v,
6663 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
6664 // CreateAlignedLoad function creates similar LLVM IR:
6665 // %res = load ptr, ptr %input, align 8
6666 // This LLVM IR does not contain information about alignment
6667 // of the loaded value. We need to add !align metadata to unblock
6668 // optimizer. The existence of the !align metadata on the instruction
6669 // tells the optimizer that the value loaded is known to be aligned to
6670 // a boundary specified by the integer value in the metadata node.
6671 // Example:
6672 // %res = load ptr, ptr %input, align 8, !align !align_md_node
6673 // ^ ^
6674 // | |
6675 // alignment of %input address |
6676 // |
6677 // alignment of %res object
6678 if (v->getType()->isPointerTy() && alignmentValue) {
6679 llvm::MDBuilder MDB(builder.getContext());
6680 loadInst->setMetadata(
6681 llvm::LLVMContext::MD_align,
6682 llvm::MDNode::get(builder.getContext(),
6683 MDB.createConstant(llvm::ConstantInt::get(
6684 llvm::Type::getInt64Ty(builder.getContext()),
6685 alignmentValue))));
6686 }
6687 retVal = loadInst;
6688
6689 break;
6690 }
6691 case omp::VariableCaptureKind::This:
6692 case omp::VariableCaptureKind::VLAType:
6693 // TODO: Consider returning error to use standard reporting for
6694 // unimplemented features.
6695 assert(false && "Currently unsupported capture kind");
6696 break;
6697 }
6698
6699 return builder.saveIP();
6700}
6701
6702/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
6703/// operation and populate output variables with their corresponding host value
6704/// (i.e. operand evaluated outside of the target region), based on their uses
6705/// inside of the target region.
6706///
6707/// Loop bounds and steps are only optionally populated, if output vectors are
6708/// provided.
6709static void
6710extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
6711 Value &numTeamsLower, Value &numTeamsUpper,
6712 Value &threadLimit,
6713 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
6714 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
6715 llvm::SmallVectorImpl<Value> *steps = nullptr) {
6716 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
6717 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
6718 blockArgIface.getHostEvalBlockArgs())) {
6719 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
6720
6721 for (Operation *user : blockArg.getUsers()) {
6723 .Case([&](omp::TeamsOp teamsOp) {
6724 if (teamsOp.getNumTeamsLower() == blockArg)
6725 numTeamsLower = hostEvalVar;
6726 else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
6727 blockArg))
6728 numTeamsUpper = hostEvalVar;
6729 else if (!teamsOp.getThreadLimitVars().empty() &&
6730 teamsOp.getThreadLimit(0) == blockArg)
6731 threadLimit = hostEvalVar;
6732 else
6733 llvm_unreachable("unsupported host_eval use");
6734 })
6735 .Case([&](omp::ParallelOp parallelOp) {
6736 if (!parallelOp.getNumThreadsVars().empty() &&
6737 parallelOp.getNumThreads(0) == blockArg)
6738 numThreads = hostEvalVar;
6739 else
6740 llvm_unreachable("unsupported host_eval use");
6741 })
6742 .Case([&](omp::LoopNestOp loopOp) {
6743 auto processBounds =
6744 [&](OperandRange opBounds,
6745 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
6746 bool found = false;
6747 for (auto [i, lb] : llvm::enumerate(opBounds)) {
6748 if (lb == blockArg) {
6749 found = true;
6750 if (outBounds)
6751 (*outBounds)[i] = hostEvalVar;
6752 }
6753 }
6754 return found;
6755 };
6756 bool found =
6757 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
6758 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
6759 found;
6760 found = processBounds(loopOp.getLoopSteps(), steps) || found;
6761 (void)found;
6762 assert(found && "unsupported host_eval use");
6763 })
6764 .DefaultUnreachable("unsupported host_eval use");
6765 }
6766 }
6767}
6768
6769/// If \p op is of the given type parameter, return it casted to that type.
6770/// Otherwise, if its immediate parent operation (or some other higher-level
6771/// parent, if \p immediateParent is false) is of that type, return that parent
6772/// casted to the given type.
6773///
6774/// If \p op is \c null or neither it or its parent(s) are of the specified
6775/// type, return a \c null operation.
6776template <typename OpTy>
6777static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
6778 if (!op)
6779 return OpTy();
6780
6781 if (OpTy casted = dyn_cast<OpTy>(op))
6782 return casted;
6783
6784 if (immediateParent)
6785 return dyn_cast_if_present<OpTy>(op->getParentOp());
6786
6787 return op->getParentOfType<OpTy>();
6788}
6789
6790/// If the given \p value is defined by an \c llvm.mlir.constant operation and
6791/// it is of an integer type, return its value.
6792static std::optional<int64_t> extractConstInteger(Value value) {
6793 if (!value)
6794 return std::nullopt;
6795
6796 if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
6797 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
6798 return constAttr.getInt();
6799
6800 return std::nullopt;
6801}
6802
6803static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
6804 uint64_t sizeInBits = dl.getTypeSizeInBits(type);
6805 uint64_t sizeInBytes = sizeInBits / 8;
6806 return sizeInBytes;
6807}
6808
6809template <typename OpTy>
6810static uint64_t getReductionDataSize(OpTy &op) {
6811 if (op.getNumReductionVars() > 0) {
6813 collectReductionDecls(op, reductions);
6814
6816 members.reserve(reductions.size());
6817 for (omp::DeclareReductionOp &red : reductions) {
6818 // For by-ref reductions, use the actual element type rather than the
6819 // pointer type so that the buffer size matches the access pattern in
6820 // the copy/reduce callbacks generated by OMPIRBuilder.
6821 if (red.getByrefElementType())
6822 members.push_back(*red.getByrefElementType());
6823 else
6824 members.push_back(red.getType());
6825 }
6826 Operation *opp = op.getOperation();
6827 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
6828 opp->getContext(), members, /*isPacked=*/false);
6829 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
6830 return getTypeByteSize(structType, dl);
6831 }
6832 return 0;
6833}
6834
6835/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
6836/// values as stated by the corresponding clauses, if constant.
6837///
6838/// These default values must be set before the creation of the outlined LLVM
6839/// function for the target region, so that they can be used to initialize the
6840/// corresponding global `ConfigurationEnvironmentTy` structure.
6841static void
6842initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
6843 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
6844 bool isTargetDevice, bool isGPU) {
6845 // TODO: Handle constant 'if' clauses.
6846
6847 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
6848 if (!isTargetDevice) {
6849 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6850 threadLimit);
6851 } else {
6852 // In the target device, values for these clauses are not passed as
6853 // host_eval, but instead evaluated prior to entry to the region. This
6854 // ensures values are mapped and available inside of the target region.
6855 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6856 numTeamsLower = teamsOp.getNumTeamsLower();
6857 // Handle num_teams upper bounds (only first value for now)
6858 if (!teamsOp.getNumTeamsUpperVars().empty())
6859 numTeamsUpper = teamsOp.getNumTeams(0);
6860 if (!teamsOp.getThreadLimitVars().empty())
6861 threadLimit = teamsOp.getThreadLimit(0);
6862 }
6863
6864 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
6865 if (!parallelOp.getNumThreadsVars().empty())
6866 numThreads = parallelOp.getNumThreads(0);
6867 }
6868 }
6869
6870 // Handle clauses impacting the number of teams.
6871
6872 int32_t minTeamsVal = 1, maxTeamsVal = -1;
6873 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
6874 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
6875 // match clang and set min and max to the same value.
6876 if (numTeamsUpper) {
6877 if (auto val = extractConstInteger(numTeamsUpper))
6878 minTeamsVal = maxTeamsVal = *val;
6879 } else {
6880 minTeamsVal = maxTeamsVal = 0;
6881 }
6882 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
6883 /*immediateParent=*/true) ||
6885 /*immediateParent=*/true)) {
6886 minTeamsVal = maxTeamsVal = 1;
6887 } else {
6888 minTeamsVal = maxTeamsVal = -1;
6889 }
6890
6891 // Handle clauses impacting the number of threads.
6892
6893 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
6894 if (!clauseValue)
6895 return;
6896
6897 if (auto val = extractConstInteger(clauseValue))
6898 result = *val;
6899
6900 // Found an applicable clause, so it's not undefined. Mark as unknown
6901 // because it's not constant.
6902 if (result < 0)
6903 result = 0;
6904 };
6905
6906 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
6907 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
6908 if (!targetOp.getThreadLimitVars().empty())
6909 setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
6910 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
6911
6912 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
6913 int32_t maxThreadsVal = -1;
6915 setMaxValueFromClause(numThreads, maxThreadsVal);
6916 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
6917 /*immediateParent=*/true))
6918 maxThreadsVal = 1;
6919
6920 // For max values, < 0 means unset, == 0 means set but unknown. Select the
6921 // minimum value between 'max_threads' and 'thread_limit' clauses that were
6922 // set.
6923 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
6924 if (combinedMaxThreadsVal < 0 ||
6925 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
6926 combinedMaxThreadsVal = teamsThreadLimitVal;
6927
6928 if (combinedMaxThreadsVal < 0 ||
6929 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
6930 combinedMaxThreadsVal = maxThreadsVal;
6931
6932 int32_t reductionDataSize = 0;
6933 if (isGPU && capturedOp) {
6934 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
6935 reductionDataSize = getReductionDataSize(teamsOp);
6936 }
6937
6938 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
6939 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
6940 assert(
6941 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
6942 omp::TargetRegionFlags::spmd) &&
6943 "invalid kernel flags");
6944 attrs.ExecFlags =
6945 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
6946 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
6947 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
6948 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
6949 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
6950 if (omp::bitEnumContainsAll(kernelFlags,
6951 omp::TargetRegionFlags::spmd |
6952 omp::TargetRegionFlags::no_loop) &&
6953 !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
6954 attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
6955
6956 attrs.MinTeams = minTeamsVal;
6957 attrs.MaxTeams.front() = maxTeamsVal;
6958 attrs.MinThreads = 1;
6959 attrs.MaxThreads.front() = combinedMaxThreadsVal;
6960 attrs.ReductionDataSize = reductionDataSize;
6961 // TODO: Allow modified buffer length similar to
6962 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
6963 if (attrs.ReductionDataSize != 0)
6964 attrs.ReductionBufferLength = 1024;
6965}
6966
6967/// Gather LLVM runtime values for all clauses evaluated in the host that are
6968/// passed to the kernel invocation.
6969///
6970/// This function must be called only when compiling for the host. Also, it will
6971/// only provide correct results if it's called after the body of \c targetOp
6972/// has been fully generated.
6973static void
6974initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
6975 LLVM::ModuleTranslation &moduleTranslation,
6976 omp::TargetOp targetOp, Operation *capturedOp,
6977 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
6978 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
6979 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
6980
6981 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
6982 llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
6983 steps(numLoops);
6984 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
6985 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
6986
6987 // TODO: Handle constant 'if' clauses.
6988 if (!targetOp.getThreadLimitVars().empty()) {
6989 Value targetThreadLimit = targetOp.getThreadLimit(0);
6990 attrs.TargetThreadLimit.front() =
6991 moduleTranslation.lookupValue(targetThreadLimit);
6992 }
6993
6994 // The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
6995 // truncate or sign extend lower and upper num_teams bounds as well as
6996 // thread_limit to match int32 ABI requirements for the OpenMP runtime.
6997 if (numTeamsLower)
6998 attrs.MinTeams = builder.CreateSExtOrTrunc(
6999 moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
7000
7001 if (numTeamsUpper)
7002 attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
7003 moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
7004
7005 if (teamsThreadLimit)
7006 attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
7007 moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
7008
7009 if (numThreads)
7010 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
7011
7012 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
7013 omp::TargetRegionFlags::trip_count)) {
7014 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7015 attrs.LoopTripCount = nullptr;
7016
7017 // To calculate the trip count, we multiply together the trip counts of
7018 // every collapsed canonical loop. We don't need to create the loop nests
7019 // here, since we're only interested in the trip count.
7020 for (auto [loopLower, loopUpper, loopStep] :
7021 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
7022 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
7023 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
7024 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
7025
7026 if (!lowerBound || !upperBound || !step) {
7027 attrs.LoopTripCount = nullptr;
7028 break;
7029 }
7030
7031 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
7032 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
7033 loc, lowerBound, upperBound, step, /*IsSigned=*/true,
7034 loopOp.getLoopInclusive());
7035
7036 if (!attrs.LoopTripCount) {
7037 attrs.LoopTripCount = tripCount;
7038 continue;
7039 }
7040
7041 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
7042 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
7043 {}, /*HasNUW=*/true);
7044 }
7045 }
7046
7047 attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
7048 if (mlir::Value devId = targetOp.getDevice()) {
7049 attrs.DeviceID = moduleTranslation.lookupValue(devId);
7050 attrs.DeviceID =
7051 builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
7052 }
7053}
7054
7055static LogicalResult
7056convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
7057 LLVM::ModuleTranslation &moduleTranslation) {
7058 auto targetOp = cast<omp::TargetOp>(opInst);
7059
7060 // The current debug location already has the DISubprogram for the outlined
7061 // function that will be created for the target op. We save it here so that
7062 // we can set it on the outlined function.
7063 llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
7064 if (failed(checkImplementationStatus(opInst)))
7065 return failure();
7066
7067 // During the handling of target op, we will generate instructions in the
7068 // parent function like call to the oulined function or branch to a new
7069 // BasicBlock. We set the debug location here to parent function so that those
7070 // get the correct debug locations. For outlined functions, the normal MLIR op
7071 // conversion will automatically pick the correct location.
7072 llvm::BasicBlock *parentBB = builder.GetInsertBlock();
7073 assert(parentBB && "No insert block is set for the builder");
7074 llvm::Function *parentLLVMFn = parentBB->getParent();
7075 assert(parentLLVMFn && "Parent Function must be valid");
7076 if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
7077 builder.SetCurrentDebugLocation(llvm::DILocation::get(
7078 parentLLVMFn->getContext(), outlinedFnLoc.getLine(),
7079 outlinedFnLoc.getCol(), SP, outlinedFnLoc.getInlinedAt()));
7080
7081 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7082 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
7083 bool isGPU = ompBuilder->Config.isGPU();
7084
7085 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
7086 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
7087 auto &targetRegion = targetOp.getRegion();
7088 // Holds the private vars that have been mapped along with the block
7089 // argument that corresponds to the MapInfoOp corresponding to the private
7090 // var in question. So, for instance:
7091 //
7092 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
7093 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
7094 //
7095 // Then, %10 has been created so that the descriptor can be used by the
7096 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
7097 // %arg0} in the mappedPrivateVars map.
7098 llvm::DenseMap<Value, Value> mappedPrivateVars;
7099 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
7100 SmallVector<Value> mapVars = targetOp.getMapVars();
7101 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
7102 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
7103 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
7104 llvm::Function *llvmOutlinedFn = nullptr;
7105 TargetDirectiveEnumTy targetDirective =
7106 getTargetDirectiveEnumTyFromOp(&opInst);
7107
7108 // TODO: It can also be false if a compile-time constant `false` IF clause is
7109 // specified.
7110 bool isOffloadEntry =
7111 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
7112
7113 // For some private variables, the MapsForPrivatizedVariablesPass
7114 // creates MapInfoOp instances. Go through the private variables and
7115 // the mapped variables so that during codegeneration we are able
7116 // to quickly look up the corresponding map variable, if any for each
7117 // private variable.
7118 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
7119 OperandRange privateVars = targetOp.getPrivateVars();
7120 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
7121 std::optional<DenseI64ArrayAttr> privateMapIndices =
7122 targetOp.getPrivateMapsAttr();
7123
7124 for (auto [privVarIdx, privVarSymPair] :
7125 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
7126 auto privVar = std::get<0>(privVarSymPair);
7127 auto privSym = std::get<1>(privVarSymPair);
7128
7129 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
7130 omp::PrivateClauseOp privatizer =
7131 findPrivatizer(targetOp, privatizerName);
7132
7133 if (!privatizer.needsMap())
7134 continue;
7135
7136 mlir::Value mappedValue =
7137 targetOp.getMappedValueForPrivateVar(privVarIdx);
7138 assert(mappedValue && "Expected to find mapped value for a privatized "
7139 "variable that needs mapping");
7140
7141 // The MapInfoOp defining the map var isn't really needed later.
7142 // So, we don't store it in any datastructure. Instead, we just
7143 // do some sanity checks on it right now.
7144 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
7145 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
7146
7147 // Check #1: Check that the type of the private variable matches
7148 // the type of the variable being mapped.
7149 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
7150 assert(
7151 varType == privVar.getType() &&
7152 "Type of private var doesn't match the type of the mapped value");
7153
7154 // Ok, only 1 sanity check for now.
7155 // Record the block argument corresponding to this mapvar.
7156 mappedPrivateVars.insert(
7157 {privVar,
7158 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
7159 (*privateMapIndices)[privVarIdx])});
7160 }
7161 }
7162
7163 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
7164 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
7165 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7166 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7167 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7168 // Forward target-cpu and target-features function attributes from the
7169 // original function to the new outlined function.
7170 llvm::Function *llvmParentFn =
7171 moduleTranslation.lookupFunction(parentFn.getName());
7172 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
7173 assert(llvmParentFn && llvmOutlinedFn &&
7174 "Both parent and outlined functions must exist at this point");
7175
7176 if (outlinedFnLoc && llvmParentFn->getSubprogram())
7177 llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());
7178
7179 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
7180 attr.isStringAttribute())
7181 llvmOutlinedFn->addFnAttr(attr);
7182
7183 if (auto attr = llvmParentFn->getFnAttribute("target-features");
7184 attr.isStringAttribute())
7185 llvmOutlinedFn->addFnAttr(attr);
7186
7187 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
7188 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7189 llvm::Value *mapOpValue =
7190 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7191 moduleTranslation.mapValue(arg, mapOpValue);
7192 }
7193 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
7194 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
7195 llvm::Value *mapOpValue =
7196 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
7197 moduleTranslation.mapValue(arg, mapOpValue);
7198 }
7199
7200 // Do privatization after moduleTranslation has already recorded
7201 // mapped values.
7202 PrivateVarsInfo privateVarsInfo(targetOp);
7203
7205 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
7206 allocaIP, &mappedPrivateVars);
7207
7208 if (failed(handleError(afterAllocas, *targetOp)))
7209 return llvm::make_error<PreviouslyReportedError>();
7210
7211 builder.restoreIP(codeGenIP);
7212 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
7213 &mappedPrivateVars),
7214 *targetOp)
7215 .failed())
7216 return llvm::make_error<PreviouslyReportedError>();
7217
7218 if (failed(copyFirstPrivateVars(
7219 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
7220 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
7221 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
7222 return llvm::make_error<PreviouslyReportedError>();
7223
7224 SmallVector<Region *> privateCleanupRegions;
7225 llvm::transform(privateVarsInfo.privatizers,
7226 std::back_inserter(privateCleanupRegions),
7227 [](omp::PrivateClauseOp privatizer) {
7228 return &privatizer.getDeallocRegion();
7229 });
7230
7232 targetRegion, "omp.target", builder, moduleTranslation);
7233
7234 if (!exitBlock)
7235 return exitBlock.takeError();
7236
7237 builder.SetInsertPoint(*exitBlock);
7238 if (!privateCleanupRegions.empty()) {
7239 if (failed(inlineOmpRegionCleanup(
7240 privateCleanupRegions, privateVarsInfo.llvmVars,
7241 moduleTranslation, builder, "omp.targetop.private.cleanup",
7242 /*shouldLoadCleanupRegionArg=*/false))) {
7243 return llvm::createStringError(
7244 "failed to inline `dealloc` region of `omp.private` "
7245 "op in the target region");
7246 }
7247 return builder.saveIP();
7248 }
7249
7250 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
7251 };
7252
7253 StringRef parentName = parentFn.getName();
7254
7255 llvm::TargetRegionEntryInfo entryInfo;
7256
7257 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
7258
7259 MapInfoData mapData;
7260 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
7261 builder, /*useDevPtrOperands=*/{},
7262 /*useDevAddrOperands=*/{}, hdaVars);
7263
7264 MapInfosTy combinedInfos;
7265 auto genMapInfoCB =
7266 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
7267 builder.restoreIP(codeGenIP);
7268 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
7269 targetDirective);
7270
7271 // Append a null entry for the implicit dyn_ptr argument so the argument
7272 // count sent to the runtime already includes it.
7273 auto *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy());
7274 combinedInfos.BasePointers.push_back(nullPtr);
7275 combinedInfos.Pointers.push_back(nullPtr);
7276 combinedInfos.DevicePointers.push_back(
7277 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
7278 combinedInfos.Sizes.push_back(builder.getInt64(0));
7279 combinedInfos.Types.push_back(
7280 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
7281 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
7282 if (!combinedInfos.Names.empty())
7283 combinedInfos.Names.push_back(nullPtr);
7284 combinedInfos.Mappers.push_back(nullptr);
7285
7286 return combinedInfos;
7287 };
7288
7289 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
7290 llvm::Value *&retVal, InsertPointTy allocaIP,
7291 InsertPointTy codeGenIP)
7292 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
7293 llvm::IRBuilderBase::InsertPointGuard guard(builder);
7294 builder.SetCurrentDebugLocation(llvm::DebugLoc());
7295 // We just return the unaltered argument for the host function
7296 // for now, some alterations may be required in the future to
7297 // keep host fallback functions working identically to the device
7298 // version (e.g. pass ByCopy values should be treated as such on
7299 // host and device, currently not always the case)
7300 if (!isTargetDevice) {
7301 retVal = cast<llvm::Value>(&arg);
7302 return codeGenIP;
7303 }
7304
7305 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
7306 *ompBuilder, moduleTranslation,
7307 allocaIP, codeGenIP);
7308 };
7309
7310 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
7311 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
7312 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
7313 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
7314 isTargetDevice, isGPU);
7315
7316 // Collect host-evaluated values needed to properly launch the kernel from the
7317 // host.
7318 if (!isTargetDevice)
7319 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
7320 targetCapturedOp, runtimeAttrs);
7321
7322 // Pass host-evaluated values as parameters to the kernel / host fallback,
7323 // except if they are constants. In any case, map the MLIR block argument to
7324 // the corresponding LLVM values.
7326 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
7327 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
7328 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
7329 llvm::Value *value = moduleTranslation.lookupValue(var);
7330 moduleTranslation.mapValue(arg, value);
7331
7332 if (!llvm::isa<llvm::Constant>(value))
7333 kernelInput.push_back(value);
7334 }
7335
7336 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
7337 // declare target arguments are not passed to kernels as arguments
7338 // TODO: We currently do not handle cases where a member is explicitly
7339 // passed in as an argument, this will likley need to be handled in
7340 // the near future, rather than using IsAMember, it may be better to
7341 // test if the relevant BlockArg is used within the target region and
7342 // then use that as a basis for exclusion in the kernel inputs.
7343 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
7344 kernelInput.push_back(mapData.OriginalValue[i]);
7345 }
7346
7347 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
7348 findAllocaInsertPoint(builder, moduleTranslation);
7349
7350 llvm::OpenMPIRBuilder::DependenciesInfo dds;
7351 if (failed(buildDependData(
7352 targetOp.getDependVars(), targetOp.getDependKinds(),
7353 targetOp.getDependIterated(), targetOp.getDependIteratedKinds(),
7354 builder, moduleTranslation, dds)))
7355 return failure();
7356
7357 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7358
7359 llvm::OpenMPIRBuilder::TargetDataInfo info(
7360 /*RequiresDevicePointerInfo=*/false,
7361 /*SeparateBeginEndCalls=*/true);
7362
7363 auto customMapperCB =
7364 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
7365 if (!combinedInfos.Mappers[i])
7366 return nullptr;
7367 info.HasMapper = true;
7368 return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
7369 moduleTranslation, targetDirective);
7370 };
7371
7372 llvm::Value *ifCond = nullptr;
7373 if (Value targetIfCond = targetOp.getIfExpr())
7374 ifCond = moduleTranslation.lookupValue(targetIfCond);
7375
7376 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7377 moduleTranslation.getOpenMPBuilder()->createTarget(
7378 ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
7379 defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
7380 argAccessorCB, customMapperCB, dds, targetOp.getNowait());
7381
7382 if (failed(handleError(afterIP, opInst)))
7383 return failure();
7384
7385 builder.restoreIP(*afterIP);
7386
7387 if (dds.DepArray)
7388 builder.CreateFree(dds.DepArray);
7389
7390 // Remap access operations to declare target reference pointers for the
7391 // device, essentially generating extra loadop's as necessary
7392 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
7393 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
7394 llvmOutlinedFn);
7395
7396 return success();
7397}
7398
7399static LogicalResult
7400convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
7401 LLVM::ModuleTranslation &moduleTranslation) {
7402 // Amend omp.declare_target by deleting the IR of the outlined functions
7403 // created for target regions. They cannot be filtered out from MLIR earlier
7404 // because the omp.target operation inside must be translated to LLVM, but
7405 // the wrapper functions themselves must not remain at the end of the
7406 // process. We know that functions where omp.declare_target does not match
7407 // omp.is_target_device at this stage can only be wrapper functions because
7408 // those that aren't are removed earlier as an MLIR transformation pass.
7409 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
7410 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
7411 op->getParentOfType<ModuleOp>().getOperation())) {
7412 if (!offloadMod.getIsTargetDevice())
7413 return success();
7414
7415 omp::DeclareTargetDeviceType declareType =
7416 attribute.getDeviceType().getValue();
7417
7418 if (declareType == omp::DeclareTargetDeviceType::host) {
7419 llvm::Function *llvmFunc =
7420 moduleTranslation.lookupFunction(funcOp.getName());
7421 llvmFunc->dropAllReferences();
7422 llvmFunc->eraseFromParent();
7423 }
7424 }
7425 return success();
7426 }
7427
7428 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
7429 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7430 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
7431 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7432 bool isDeclaration = gOp.isDeclaration();
7433 bool isExternallyVisible =
7434 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
7435 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
7436 llvm::StringRef mangledName = gOp.getSymName();
7437 auto captureClause =
7438 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
7439 auto deviceClause =
7440 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
7441 // unused for MLIR at the moment, required in Clang for book
7442 // keeping
7443 std::vector<llvm::GlobalVariable *> generatedRefs;
7444
7445 std::vector<llvm::Triple> targetTriple;
7446 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
7447 op->getParentOfType<mlir::ModuleOp>()->getAttr(
7448 LLVM::LLVMDialect::getTargetTripleAttrName()));
7449 if (targetTripleAttr)
7450 targetTriple.emplace_back(targetTripleAttr.data());
7451
7452 auto fileInfoCallBack = [&loc]() {
7453 std::string filename = "";
7454 std::uint64_t lineNo = 0;
7455
7456 if (loc) {
7457 filename = loc.getFilename().str();
7458 lineNo = loc.getLine();
7459 }
7460
7461 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
7462 lineNo);
7463 };
7464
7465 auto vfs = llvm::vfs::getRealFileSystem();
7466
7467 ompBuilder->registerTargetGlobalVariable(
7468 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7469 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7470 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7471 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
7472 gVal->getType(), gVal);
7473
7474 if (ompBuilder->Config.isTargetDevice() &&
7475 (attribute.getCaptureClause().getValue() !=
7476 mlir::omp::DeclareTargetCaptureClause::to ||
7477 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
7478 ompBuilder->getAddrOfDeclareTargetVar(
7479 captureClause, deviceClause, isDeclaration, isExternallyVisible,
7480 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
7481 mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
7482 gVal->getType(), /*GlobalInitializer*/ nullptr,
7483 /*VariableLinkage*/ nullptr);
7484 }
7485 }
7486 }
7487
7488 return success();
7489}
7490
7491namespace {
7492
7493/// Implementation of the dialect interface that converts operations belonging
7494/// to the OpenMP dialect to LLVM IR.
7495class OpenMPDialectLLVMIRTranslationInterface
7496 : public LLVMTranslationDialectInterface {
7497public:
7498 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
7499
7500 /// Translates the given operation to LLVM IR using the provided IR builder
7501 /// and saving the state in `moduleTranslation`.
7502 LogicalResult
7503 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
7504 LLVM::ModuleTranslation &moduleTranslation) const final;
7505
7506 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
7507 /// runtime calls, or operation amendments
7508 LogicalResult
7509 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
7510 NamedAttribute attribute,
7511 LLVM::ModuleTranslation &moduleTranslation) const final;
7512
7513 /// Records the LLVM alloc pointer produced for an OMP ALLOCATE variable so
7514 /// that the paired omp.allocate_free op can generate the matching
7515 /// __kmpc_free call.
7516 void registerAllocatedPtr(Value var, llvm::Value *ptr) const {
7517 ompAllocatedPtrs[var] = ptr;
7518 }
7519
7520 /// Returns the LLVM alloc pointer previously registered for var, or
7521 /// nullptr if no allocation was recorded.
7522 llvm::Value *lookupAllocatedPtr(Value var) const {
7523 auto it = ompAllocatedPtrs.find(var);
7524 return it != ompAllocatedPtrs.end() ? it->second : nullptr;
7525 }
7526
7527private:
7528 /// Maps each MLIR variable value that appeared in an omp.allocate_dir op to
7529 /// the LLVM pointer returned by the corresponding __kmpc_alloc call. The
7530 /// paired omp.allocate_free op looks up these pointers to emit __kmpc_free.
7531 mutable DenseMap<Value, llvm::Value *> ompAllocatedPtrs;
7532};
7533
7534} // namespace
7535
7536LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
7537 Operation *op, ArrayRef<llvm::Instruction *> instructions,
7538 NamedAttribute attribute,
7539 LLVM::ModuleTranslation &moduleTranslation) const {
7540 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
7541 attribute.getName())
7542 .Case("omp.is_target_device",
7543 [&](Attribute attr) {
7544 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
7545 llvm::OpenMPIRBuilderConfig &config =
7546 moduleTranslation.getOpenMPBuilder()->Config;
7547 config.setIsTargetDevice(deviceAttr.getValue());
7548 return success();
7549 }
7550 return failure();
7551 })
7552 .Case("omp.is_gpu",
7553 [&](Attribute attr) {
7554 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
7555 llvm::OpenMPIRBuilderConfig &config =
7556 moduleTranslation.getOpenMPBuilder()->Config;
7557 config.setIsGPU(gpuAttr.getValue());
7558 return success();
7559 }
7560 return failure();
7561 })
7562 .Case("omp.host_ir_filepath",
7563 [&](Attribute attr) {
7564 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
7565 llvm::OpenMPIRBuilder *ompBuilder =
7566 moduleTranslation.getOpenMPBuilder();
7567 auto VFS = llvm::vfs::getRealFileSystem();
7568 ompBuilder->loadOffloadInfoMetadata(*VFS,
7569 filepathAttr.getValue());
7570 return success();
7571 }
7572 return failure();
7573 })
7574 .Case("omp.flags",
7575 [&](Attribute attr) {
7576 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
7577 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
7578 return failure();
7579 })
7580 .Case("omp.version",
7581 [&](Attribute attr) {
7582 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
7583 llvm::OpenMPIRBuilder *ompBuilder =
7584 moduleTranslation.getOpenMPBuilder();
7585 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
7586 versionAttr.getVersion());
7587 return success();
7588 }
7589 return failure();
7590 })
7591 .Case("omp.declare_target",
7592 [&](Attribute attr) {
7593 if (auto declareTargetAttr =
7594 dyn_cast<omp::DeclareTargetAttr>(attr))
7595 return convertDeclareTargetAttr(op, declareTargetAttr,
7596 moduleTranslation);
7597 return failure();
7598 })
7599 .Case("omp.requires",
7600 [&](Attribute attr) {
7601 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
7602 using Requires = omp::ClauseRequires;
7603 Requires flags = requiresAttr.getValue();
7604 llvm::OpenMPIRBuilderConfig &config =
7605 moduleTranslation.getOpenMPBuilder()->Config;
7606 config.setHasRequiresReverseOffload(
7607 bitEnumContainsAll(flags, Requires::reverse_offload));
7608 config.setHasRequiresUnifiedAddress(
7609 bitEnumContainsAll(flags, Requires::unified_address));
7610 config.setHasRequiresUnifiedSharedMemory(
7611 bitEnumContainsAll(flags, Requires::unified_shared_memory));
7612 config.setHasRequiresDynamicAllocators(
7613 bitEnumContainsAll(flags, Requires::dynamic_allocators));
7614 return success();
7615 }
7616 return failure();
7617 })
7618 .Case("omp.target_triples",
7619 [&](Attribute attr) {
7620 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
7621 llvm::OpenMPIRBuilderConfig &config =
7622 moduleTranslation.getOpenMPBuilder()->Config;
7623 config.TargetTriples.clear();
7624 config.TargetTriples.reserve(triplesAttr.size());
7625 for (Attribute tripleAttr : triplesAttr) {
7626 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
7627 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
7628 else
7629 return failure();
7630 }
7631 return success();
7632 }
7633 return failure();
7634 })
7635 .Default([](Attribute) {
7636 // Fall through for omp attributes that do not require lowering.
7637 return success();
7638 })(attribute.getValue());
7639
7640 return failure();
7641}
7642
7643// Returns true if the operation is not inside a TargetOp, it is part of a
7644// function and that function is not declare target.
7645static bool isHostDeviceOp(Operation *op) {
7646 // Assumes no reverse offloading
7647 if (op->getParentOfType<omp::TargetOp>())
7648 return false;
7649
7650 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) {
7651 if (auto declareTargetIface =
7652 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
7653 parentFn.getOperation()))
7654 if (declareTargetIface.isDeclareTarget() &&
7655 declareTargetIface.getDeclareTargetDeviceType() !=
7656 mlir::omp::DeclareTargetDeviceType::host)
7657 return false;
7658
7659 return true;
7660 }
7661
7662 return false;
7663}
7664
7665static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
7666 llvm::Module *llvmModule) {
7667 llvm::Type *i64Ty = builder.getInt64Ty();
7668 llvm::Type *i32Ty = builder.getInt32Ty();
7669 llvm::Type *returnType = builder.getPtrTy(0);
7670 llvm::FunctionType *fnType =
7671 llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
7672 llvm::Function *func = cast<llvm::Function>(
7673 llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
7674 return func;
7675}
7676
7677static LogicalResult
7678convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7679 LLVM::ModuleTranslation &moduleTranslation) {
7680 auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
7681 if (!allocMemOp)
7682 return failure();
7683
7684 // Get "omp_target_alloc" function
7685 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7686 llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
7687 // Get the corresponding device value in llvm
7688 mlir::Value deviceNum = allocMemOp.getDevice();
7689 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7690 // Get the allocation size.
7691 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7692 mlir::Type heapTy = allocMemOp.getAllocatedType();
7693 llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
7694 llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
7695 llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
7696 for (auto typeParam : allocMemOp.getTypeparams())
7697 allocSize =
7698 builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
7699 // Create call to "omp_target_alloc" with the args as translated llvm values.
7700 llvm::CallInst *call =
7701 builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
7702 llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
7703
7704 // Map the result
7705 moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
7706 return success();
7707}
7708
7709static LogicalResult
7710convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
7711 LLVM::ModuleTranslation &moduleTranslation,
7712 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
7713 auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
7714 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7715
7716 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7717 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7718 llvm::DataLayout dataLayout = llvmModule->getDataLayout();
7719 SmallVector<Value> vars = allocateDirOp.getVarList();
7720 std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
7721
7722 llvm::Value *allocator;
7723 if (auto allocatorVar = allocateDirOp.getAllocator()) {
7724 allocator = moduleTranslation.lookupValue(allocatorVar);
7725 if (allocator->getType()->isIntegerTy())
7726 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
7727 else if (allocator->getType()->isPointerTy())
7728 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
7729 allocator, builder.getPtrTy());
7730 } else {
7731 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
7732 }
7733
7734 for (Value var : vars) {
7735 llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType());
7736
7737 // Opaque pointers lose element type. Trace to GlobalOp for type
7738 // Falls back to llvmVarTy when not from a global.
7739 llvm::Type *typeToInspect = llvmVarTy;
7740 if (llvmVarTy->isPointerTy()) {
7741 Value baseVar = getBaseValueForTypeLookup(var);
7742 if (Operation *globalOp = getGlobalOpFromValue(baseVar)) {
7743 if (auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
7744 typeToInspect = moduleTranslation.convertType(gop.getGlobalType());
7745 }
7746 }
7747
7748 llvm::Value *size;
7749 if (auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
7750 llvm::Value *elementCount = builder.getInt64(1);
7751 llvm::Type *currentType = arrTy;
7752 while (auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
7753 elementCount = builder.CreateMul(
7754 elementCount, builder.getInt64(nestedArrTy->getNumElements()));
7755 currentType = nestedArrTy->getElementType();
7756 }
7757 uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
7758 size =
7759 builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8));
7760 } else {
7761 size = builder.getInt64(
7762 dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
7763 }
7764
7765 uint64_t alignValue =
7766 alignAttr ? alignAttr.value()
7767 : dataLayout.getABITypeAlign(typeToInspect).value();
7768 llvm::Value *alignConst = builder.getInt64(alignValue);
7769 // Align the size: ((size + align - 1) / align) * align
7770 size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true);
7771 size = builder.CreateUDiv(size, alignConst);
7772 size = builder.CreateMul(size, alignConst, "", true);
7773
7774 std::string allocName =
7775 ompBuilder->createPlatformSpecificName({".void.addr"});
7776 llvm::CallInst *allocCall;
7777 if (alignAttr.has_value()) {
7778 allocCall = ompBuilder->createOMPAlignedAlloc(
7779 ompLoc, builder.getInt64(alignAttr.value()), size, allocator,
7780 allocName);
7781 } else {
7782 allocCall =
7783 ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
7784 }
7785 // Record the alloc pointer keyed by the MLIR variable value.
7786 ompIface.registerAllocatedPtr(var, allocCall);
7787 }
7788
7789 return success();
7790}
7791
7792static LogicalResult
7793convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder,
7794 LLVM::ModuleTranslation &moduleTranslation,
7795 const OpenMPDialectLLVMIRTranslationInterface &ompIface) {
7796 auto freeOp = cast<omp::AllocateFreeOp>(opInst);
7797 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7798 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
7799
7800 llvm::Value *allocator;
7801 if (auto allocatorVar = freeOp.getAllocator()) {
7802 allocator = moduleTranslation.lookupValue(allocatorVar);
7803 if (allocator->getType()->isIntegerTy())
7804 allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
7805 else if (allocator->getType()->isPointerTy())
7806 allocator = builder.CreatePointerBitCastOrAddrSpaceCast(
7807 allocator, builder.getPtrTy());
7808 } else {
7809 allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
7810 }
7811
7812 // Emit __kmpc_free for each variable in reverse allocation order.
7813 SmallVector<Value> vars = freeOp.getVarList();
7814 for (Value var : llvm::reverse(vars)) {
7815 llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var);
7816 if (!allocPtr)
7817 return opInst.emitError("omp.allocate_free: no allocation recorded");
7818 ompBuilder->createOMPFree(ompLoc, allocPtr, allocator, "");
7819 }
7820
7821 return success();
7822}
7823
7824static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
7825 llvm::Module *llvmModule) {
7826 llvm::Type *ptrTy = builder.getPtrTy(0);
7827 llvm::Type *i32Ty = builder.getInt32Ty();
7828 llvm::Type *voidTy = builder.getVoidTy();
7829 llvm::FunctionType *fnType =
7830 llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
7831 llvm::Function *func = dyn_cast<llvm::Function>(
7832 llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
7833 return func;
7834}
7835
7836static LogicalResult
7837convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
7838 LLVM::ModuleTranslation &moduleTranslation) {
7839 auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
7840 if (!freeMemOp)
7841 return failure();
7842
7843 // Get "omp_target_free" function
7844 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
7845 llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
7846 // Get the corresponding device value in llvm
7847 mlir::Value deviceNum = freeMemOp.getDevice();
7848 llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
7849 // Get the corresponding heapref value in llvm
7850 mlir::Value heapref = freeMemOp.getHeapref();
7851 llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
7852 // Convert heapref int to ptr and call "omp_target_free"
7853 llvm::Value *intToPtr =
7854 builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
7855 builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
7856 return success();
7857}
7858
7859/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
7860/// OpenMP runtime calls).
7861LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
7862 Operation *op, llvm::IRBuilderBase &builder,
7863 LLVM::ModuleTranslation &moduleTranslation) const {
7864 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
7865
7866 if (ompBuilder->Config.isTargetDevice() &&
7867 !isa<omp::TargetOp, omp::MapInfoOp, omp::TerminatorOp, omp::YieldOp>(
7868 op) &&
7869 isHostDeviceOp(op))
7870 return op->emitOpError() << "unsupported host op found in device";
7871
7872 // For each loop, introduce one stack frame to hold loop information. Ensure
7873 // this is only done for the outermost loop wrapper to prevent introducing
7874 // multiple stack frames for a single loop. Initially set to null, the loop
7875 // information structure is initialized during translation of the nested
7876 // omp.loop_nest operation, making it available to translation of all loop
7877 // wrappers after their body has been successfully translated.
7878 bool isOutermostLoopWrapper =
7879 isa_and_present<omp::LoopWrapperInterface>(op) &&
7880 !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
7881
7882 // The TASKLOOP construct is implemented with an outer taskloop.context
7883 // operation which is not a loop wrapper, containing an inner taskloop
7884 // operation which is a loop wrapper. The stack frame should be pushed when
7885 // translating the outer taskloop.context and popped when translating the
7886 // inner taskloop which is a loop wrapper. We need access to the loop
7887 // information in the outer taskloop context so we need to create it and pop
7888 // it around the taskloop context not the inner loop wrapper.
7889 if (isa<omp::TaskloopContextOp>(op))
7890 isOutermostLoopWrapper = true;
7891 else if (isa<omp::TaskloopWrapperOp>(op))
7892 isOutermostLoopWrapper = false;
7893
7894 if (isOutermostLoopWrapper)
7895 moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
7896
7897 auto result =
7898 llvm::TypeSwitch<Operation *, LogicalResult>(op)
7899 .Case([&](omp::BarrierOp op) -> LogicalResult {
7901 return failure();
7902
7903 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
7904 ompBuilder->createBarrier(builder.saveIP(),
7905 llvm::omp::OMPD_barrier);
7906 LogicalResult res = handleError(afterIP, *op);
7907 if (res.succeeded()) {
7908 // If the barrier generated a cancellation check, the insertion
7909 // point might now need to be changed to a new continuation block
7910 builder.restoreIP(*afterIP);
7911 }
7912 return res;
7913 })
7914 .Case([&](omp::TaskyieldOp op) {
7916 return failure();
7917
7918 ompBuilder->createTaskyield(builder.saveIP());
7919 return success();
7920 })
7921 .Case([&](omp::FlushOp op) {
7923 return failure();
7924
7925 // No support in Openmp runtime function (__kmpc_flush) to accept
7926 // the argument list.
7927 // OpenMP standard states the following:
7928 // "An implementation may implement a flush with a list by ignoring
7929 // the list, and treating it the same as a flush without a list."
7930 //
7931 // The argument list is discarded so that, flush with a list is
7932 // treated same as a flush without a list.
7933 ompBuilder->createFlush(builder.saveIP());
7934 return success();
7935 })
7936 .Case([&](omp::ParallelOp op) {
7937 return convertOmpParallel(op, builder, moduleTranslation);
7938 })
7939 .Case([&](omp::MaskedOp) {
7940 return convertOmpMasked(*op, builder, moduleTranslation);
7941 })
7942 .Case([&](omp::MasterOp) {
7943 return convertOmpMaster(*op, builder, moduleTranslation);
7944 })
7945 .Case([&](omp::CriticalOp) {
7946 return convertOmpCritical(*op, builder, moduleTranslation);
7947 })
7948 .Case([&](omp::OrderedRegionOp) {
7949 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
7950 })
7951 .Case([&](omp::OrderedOp) {
7952 return convertOmpOrdered(*op, builder, moduleTranslation);
7953 })
7954 .Case([&](omp::WsloopOp) {
7955 return convertOmpWsloop(*op, builder, moduleTranslation);
7956 })
7957 .Case([&](omp::SimdOp) {
7958 return convertOmpSimd(*op, builder, moduleTranslation);
7959 })
7960 .Case([&](omp::AtomicReadOp) {
7961 return convertOmpAtomicRead(*op, builder, moduleTranslation);
7962 })
7963 .Case([&](omp::AtomicWriteOp) {
7964 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
7965 })
7966 .Case([&](omp::AtomicUpdateOp op) {
7967 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
7968 })
7969 .Case([&](omp::AtomicCaptureOp op) {
7970 return convertOmpAtomicCapture(op, builder, moduleTranslation);
7971 })
7972 .Case([&](omp::CancelOp op) {
7973 return convertOmpCancel(op, builder, moduleTranslation);
7974 })
7975 .Case([&](omp::CancellationPointOp op) {
7976 return convertOmpCancellationPoint(op, builder, moduleTranslation);
7977 })
7978 .Case([&](omp::SectionsOp) {
7979 return convertOmpSections(*op, builder, moduleTranslation);
7980 })
7981 .Case([&](omp::SingleOp op) {
7982 return convertOmpSingle(op, builder, moduleTranslation);
7983 })
7984 .Case([&](omp::TeamsOp op) {
7985 return convertOmpTeams(op, builder, moduleTranslation);
7986 })
7987 .Case([&](omp::TaskOp op) {
7988 return convertOmpTaskOp(op, builder, moduleTranslation);
7989 })
7990 .Case([&](omp::TaskloopWrapperOp op) {
7991 return convertOmpTaskloopWrapperOp(op, builder, moduleTranslation);
7992 })
7993 .Case([&](omp::TaskloopContextOp op) {
7994 return convertOmpTaskloopContextOp(op, builder, moduleTranslation);
7995 })
7996 .Case([&](omp::TaskgroupOp op) {
7997 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
7998 })
7999 .Case([&](omp::TaskwaitOp op) {
8000 return convertOmpTaskwaitOp(op, builder, moduleTranslation);
8001 })
8002 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
8003 omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
8004 omp::CriticalDeclareOp>([](auto op) {
8005 // `yield` and `terminator` can be just omitted. The block structure
8006 // was created in the region that handles their parent operation.
8007 // `declare_reduction` will be used by reductions and is not
8008 // converted directly, skip it.
8009 // `declare_mapper` and `declare_mapper.info` are handled whenever
8010 // they are referred to through a `map` clause.
8011 // `critical.declare` is only used to declare names of critical
8012 // sections which will be used by `critical` ops and hence can be
8013 // ignored for lowering. The OpenMP IRBuilder will create unique
8014 // name for critical section names.
8015 return success();
8016 })
8017 .Case([&](omp::ThreadprivateOp) {
8018 return convertOmpThreadprivate(*op, builder, moduleTranslation);
8019 })
8020 .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
8021 omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
8022 return convertOmpTargetData(op, builder, moduleTranslation);
8023 })
8024 .Case([&](omp::TargetOp) {
8025 return convertOmpTarget(*op, builder, moduleTranslation);
8026 })
8027 .Case([&](omp::DistributeOp) {
8028 return convertOmpDistribute(*op, builder, moduleTranslation);
8029 })
8030 .Case([&](omp::LoopNestOp) {
8031 return convertOmpLoopNest(*op, builder, moduleTranslation);
8032 })
8033 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp,
8034 omp::AffinityEntryOp, omp::IteratorOp>([&](auto op) {
8035 // No-op, should be handled by relevant owning operations e.g.
8036 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
8037 // etc. and then discarded
8038 return success();
8039 })
8040 .Case([&](omp::NewCliOp op) {
8041 // Meta-operation: Doesn't do anything by itself, but used to
8042 // identify a loop.
8043 return success();
8044 })
8045 .Case([&](omp::CanonicalLoopOp op) {
8046 return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
8047 })
8048 .Case([&](omp::UnrollHeuristicOp op) {
8049 // FIXME: Handling omp.unroll_heuristic as an executable requires
8050 // that the generator (e.g. omp.canonical_loop) has been seen first.
8051 // For construct that require all codegen to occur inside a callback
8052 // (e.g. OpenMPIRBilder::createParallel), all codegen of that
8053 // contained region including their transformations must occur at
8054 // the omp.canonical_loop.
8055 return applyUnrollHeuristic(op, builder, moduleTranslation);
8056 })
8057 .Case([&](omp::TileOp op) {
8058 return applyTile(op, builder, moduleTranslation);
8059 })
8060 .Case([&](omp::FuseOp op) {
8061 return applyFuse(op, builder, moduleTranslation);
8062 })
8063 .Case([&](omp::TargetAllocMemOp) {
8064 return convertTargetAllocMemOp(*op, builder, moduleTranslation);
8065 })
8066 .Case([&](omp::TargetFreeMemOp) {
8067 return convertTargetFreeMemOp(*op, builder, moduleTranslation);
8068 })
8069 .Case([&](omp::AllocateDirOp) {
8070 return convertAllocateDirOp(*op, builder, moduleTranslation, *this);
8071 })
8072 .Case([&](omp::AllocateFreeOp) {
8073 return convertAllocateFreeOp(*op, builder, moduleTranslation,
8074 *this);
8075 })
8076 .Default([&](Operation *inst) {
8077 return inst->emitError()
8078 << "not yet implemented: " << inst->getName();
8079 });
8080
8081 if (isOutermostLoopWrapper)
8082 moduleTranslation.stackPop();
8083
8084 return result;
8085}
8086
8088 registry.insert<omp::OpenMPDialect>();
8089 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
8090 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
8091 });
8092}
8093
8095 DialectRegistry registry;
8097 context.appendDialectRegistry(registry);
8098}
for(Operation *op :ops)
return success()
lhs
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static mlir::LogicalResult buildDependData(OperandRange dependVars, std::optional< ArrayAttr > dependKinds, OperandRange dependIterated, std::optional< ArrayAttr > dependIteratedKinds, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps)
static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func)
static LogicalResult convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static llvm::omp::OrderKind convertOrderKind(std::optional< omp::ClauseOrderKind > o)
Convert Order attribute to llvm::omp::OrderKind.
static LogicalResult convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
owningDataPtrPtrReductionGens[i]
static LogicalResult convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Operation * getGlobalOpFromValue(Value value)
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind convertToCaptureClauseKind(mlir::omp::DeclareTargetCaptureClause captureClause)
static mlir::LogicalResult convertIteratorRegion(llvm::Value *linearIV, IteratorInfo &iterInfo, mlir::Block &iteratorRegionBlock, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first)
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent=false)
If op is of the given type parameter, return it casted to that type. Otherwise, if its immediate pare...
static LogicalResult convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered_region' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void processMapMembersWithParent(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an omp.atomic.write operation to LLVM IR.
static OwningAtomicReductionGen makeAtomicReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible atomic reduction generator for the given reduction declaration.
static OwningDataPtrPtrReductionGen makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, bool isByRef)
Create an OpenMPIRBuilder-compatible data_ptr_ptr reduction generator for the given reduction declara...
static void popCancelFinalizationCB(const ArrayRef< llvm::UncondBrInst * > cancelTerminators, llvm::OpenMPIRBuilder &ompBuilder, const llvm::OpenMPIRBuilder::InsertPointTy &afterIP)
If we cancelled the construct, we should branch to the finalization block of that construct....
static llvm::Value * getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp unroll / "!$omp unroll" transformation using the OpenMPIRBuilder.
static LogicalResult convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
static void getAsIntegers(ArrayAttr values, llvm::SmallVector< int64_t > &ints)
static llvm::Value * findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Return the llvm::Value * corresponding to the privateVar that is being privatized....
static ArrayRef< bool > getIsByRef(std::optional< ArrayRef< bool > > attr)
static llvm::Expected< llvm::Value * > lookupOrTranslatePureValue(Value value, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
Look up the given value in the mapping, and if it's not there, translate its defining operation at th...
static LogicalResult convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static mlir::LogicalResult fillIteratorLoop(mlir::omp::IteratorOp itersOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, IteratorInfo &iterInfo, llvm::StringRef loopName, IteratorStoreEntryTy genStoreEntry)
static LogicalResult cleanupPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, Location loc, SmallVectorImpl< llvm::Value * > &llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls)
static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder)
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirectiveEnumTy targetDirective)
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static void setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder, llvm::BasicBlock *block=nullptr)
llvm::function_ref< void(llvm::Value *linearIV, mlir::omp::YieldOp yield)> IteratorStoreEntryTy
static LogicalResult convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static omp::DistributeOp getDistributeCapturingTeamsReduction(omp::TeamsOp teamsOp)
static LogicalResult convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert an omp.canonical_loop to LLVM-IR.
static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static std::optional< int64_t > extractConstInteger(Value value)
If the given value is defined by an llvm.mlir.constant operation and it is of an integer type,...
static llvm::Expected< llvm::Value * > initPrivateVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::PrivateClauseOp &privDecl, llvm::Value *nonPrivateVar, BlockArgument &blockArg, llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Initialize a single (first)private variable. You probably want to use allocateAndInitPrivateVars inst...
static mlir::LogicalResult buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, mlir::LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::AffinityData &ad)
static LogicalResult allocAndInitializeReductionVars(OP op, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, llvm::ArrayRef< bool > isByRef)
static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Function * getOmpTargetAlloc(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static llvm::Expected< llvm::Function * > emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName, TargetDirectiveEnumTy targetDirective)
static llvm::Expected< llvm::BasicBlock * > allocatePrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
Allocate and initialize delayed private variables. Returns the basic block which comes after all of t...
static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, TargetDirectiveEnumTy targetDirective)
static LogicalResult convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op)
Converts an LLVM dialect binary operation to the corresponding enum value for atomicrmw supported bin...
static LogicalResult convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp)
allocatedType moduleTranslation static convertType(allocatedType) LogicalResult inlineOmpRegionCleanup(llvm::SmallVectorImpl< Region * > &cleanupRegions, llvm::ArrayRef< llvm::Value * > privateVariables, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, StringRef regionName, bool shouldLoadCleanupRegionArg=true)
handling of DeclareReductionOp's cleanup region
static LogicalResult applyFuse(omp::FuseOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp fuse / !$omp fuse transformation using the OpenMPIRBuilder.
static llvm::SmallString< 64 > getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Error initPrivateVars(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, PrivateVarsInfo &privateVarsInfo, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static llvm::CanonicalLoopInfo * findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation)
Find the loop information structure for the loop nest being translated.
static OwningReductionGen makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Create an OpenMPIRBuilder-compatible reduction generator for the given reduction declaration.
static std::vector< llvm::Value * > calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, bool isArrayTy, OperandRange bounds)
This function calculates the array/pointer offset for map data provided with bounds operations,...
static void storeAffinityEntry(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, llvm::Value *affinityList, llvm::Value *index, llvm::Value *addr, llvm::Value *len)
static LogicalResult convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts the OpenMP parallel operation to LLVM IR.
static void pushCancelFinalizationCB(SmallVectorImpl< llvm::UncondBrInst * > &cancelTerminators, llvm::IRBuilderBase &llvmBuilder, llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op, llvm::omp::Directive cancelDirective)
Shared implementation of a callback which adds a termiator for the new block created for the branch t...
static llvm::OpenMPIRBuilder::InsertPointTy findAllocaInsertPoint(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Find the insertion point for allocas given the current insertion point for normal operations in the b...
static LogicalResult inlineConvertOmpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > *continuationBlockArgs=nullptr)
Translates the blocks contained in the given region and appends them to at the current insertion poin...
static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP Threadprivate operation into LLVM IR using OpenMPIRBuilder.
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Expected< llvm::Function * > getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, TargetDirectiveEnumTy targetDirective)
static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl)
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, MapInfosTy &combinedInfo, TargetDirectiveEnumTy targetDirective, int mapDataParentIdx=-1)
static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, omp::TargetOp targetOp, llvm::StringRef parentName="")
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, llvm::SmallVectorImpl< Value > *lowerBounds=nullptr, llvm::SmallVectorImpl< Value > *upperBounds=nullptr, llvm::SmallVectorImpl< Value > *steps=nullptr)
Follow uses of host_eval-defined block arguments of the given omp.target operation and populate outpu...
static llvm::Expected< llvm::BasicBlock * > convertOmpOpRegions(Region &region, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::PHINode * > *continuationBlockPHIs=nullptr)
Converts the given region that appears within an OpenMP dialect operation to LLVM IR,...
static LogicalResult copyFirstPrivateVars(mlir::Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::Value * > &moldVars, ArrayRef< llvm::Value * > llvmPrivateVars, SmallVectorImpl< omp::PrivateClauseOp > &privateDecls, bool insertBarrier, llvm::DenseMap< Value, Value > *mappedPrivateVars=nullptr)
static LogicalResult allocReductionVars(T loop, ArrayRef< BlockArgument > reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, SmallVectorImpl< llvm::Value * > &privateReductionVariables, DenseMap< Value, llvm::Value * > &reductionVariableMap, SmallVectorImpl< DeferredStore > &deferredStores, llvm::ArrayRef< bool > isByRefs)
Allocate space for privatized reduction variables.
static LogicalResult convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static bool constructIsCancellable(Operation *op)
Returns true if the construct contains omp.cancel or omp.cancellation_point.
static llvm::omp::OpenMPOffloadMappingFlags convertClauseMapFlags(omp::ClauseMapFlags mlirFlags)
static void buildDependDataLocator(std::optional< ArrayAttr > dependKinds, OperandRange dependVars, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl< llvm::OpenMPIRBuilder::DependData > &dds)
static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP)
static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, DenseMap< Value, llvm::Value * > &reductionVariableMap, unsigned i)
Map input arguments to reduction initialization region.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind)
Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static void fillAffinityLocators(Operation::operand_range affinityVars, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *affinityList)
static LogicalResult convertOmpTaskloopWrapperOp(omp::TaskloopWrapperOp loopWrapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
The correct entry point is convertOmpTaskloopContextOp. This gets called whilst lowering the body of ...
static void getOverlappedMembers(llvm::SmallVectorImpl< size_t > &overlapMapDataIdxs, omp::MapInfoOp parentOp)
static LogicalResult convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
static bool isDeclareTargetTo(Value value)
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl)
static void collectReductionDecls(T op, SmallVectorImpl< omp::DeclareReductionOp > &reductions)
Populates reductions with reduction declarations used in the given op.
static LogicalResult handleError(llvm::Error error, Operation &op)
static LogicalResult convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause)
static llvm::Error computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::Value *&lbVal, llvm::Value *&ubVal, llvm::Value *&stepVal)
static LogicalResult checkImplementationStatus(Operation &op)
Check whether translation to LLVM IR for the given operation is currently supported.
static LogicalResult createReductionsAndCleanup(OP op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl< omp::DeclareReductionOp > &reductionDecls, ArrayRef< llvm::Value * > privateReductionVariables, ArrayRef< bool > isByRef, bool isNowait=false, bool isTeamsReduction=false)
static LogicalResult convertOmpCancellationPoint(omp::CancellationPointOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static bool opIsInSingleThread(mlir::Operation *op)
This can't always be determined statically, but when we can, it is good to avoid generating compiler-...
static uint64_t getReductionDataSize(OpTy &op)
static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Convert omp.atomic.read operation to LLVM IR.
static llvm::omp::Directive convertCancellationConstructType(omp::ClauseCancellationConstructType directive)
static void initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, bool isTargetDevice, bool isGPU)
Populate default MinTeams, MaxTeams and MaxThreads to their default values as stated by the correspon...
static llvm::omp::RTLDependenceKindTy convertDependKind(mlir::omp::ClauseTaskDepend kind)
static void initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, Operation *capturedOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs)
Gather LLVM runtime values for all clauses evaluated in the host that are passed to the kernel invoca...
static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static Value getBaseValueForTypeLookup(Value value)
static bool isHostDeviceOp(Operation *op)
static bool isDeclareTargetLink(Value value)
static LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute, LLVM::ModuleTranslation &moduleTranslation)
Lowers the FlagsAttr which is applied to the module on the device pass when offloading,...
static bool checkIfPointerMap(omp::MapInfoOp mapOp)
static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Apply a #pragma omp tile / !$omp tile transformation using the OpenMPIRBuilder.
static LogicalResult convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, const OpenMPDialectLLVMIRTranslationInterface &ompIface)
static void sortMapIndices(llvm::SmallVectorImpl< size_t > &indices, omp::MapInfoOp mapInfo)
static llvm::Function * getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule)
static LogicalResult convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static void collectMapDataFromMapOperands(MapInfoData &mapData, SmallVectorImpl< Value > &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, llvm::IRBuilderBase &builder, ArrayRef< Value > useDevPtrOperands={}, ArrayRef< Value > useDevAddrOperands={}, ArrayRef< Value > hasDevAddrOperands={})
static void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp, bool &isIgnoreDenormalMode, bool &isFineGrainedMemory, bool &isRemoteMemory)
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator begin()
Definition Block.h:153
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
Implementation class for module translation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
WalkResult stackWalk(llvm::function_ref< WalkResult(T &)> callback)
Calls callback for every ModuleTranslation stack frame of type T starting from the top of the stack.
void stackPush(Args &&...args)
Creates a stack frame of type T on ModuleTranslation stack.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void invalidateOmpLoop(omp::NewCliOp mlir)
Mark an OpenMP loop as having been consumed.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm)
Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR OpenMPIRBuilder CanonicalLoopInfo...
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
SaveStateStack< T, ModuleTranslation > SaveStack
RAII object calling stackPush/stackPop on construction/destruction.
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder)
Converts the given MLIR operation into LLVM IR using this translator.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void stackPop()
Pops the last element from the ModuleTranslation stack.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::CanonicalLoopInfo * lookupOMPLoop(omp::NewCliOp mlir) const
Find the LLVM-IR loop that represents an MLIR loop.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Value getOperand(unsigned idx)
Definition Operation.h:376
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
unsigned getNumOperands()
Definition Operation.h:372
OperandRange operand_range
Definition Operation.h:397
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
result_range getResults()
Definition Operation.h:441
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Concrete CRTP base class for StateStack frames.
Definition StateStack.h:47
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Constant * createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder)
Create a constant string representing the mapping information extracted from the MLIR location inform...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
void registerOpenMPDialectTranslation(DialectRegistry &registry)
Register the OpenMP dialect and the translation from it to the LLVM IR in the given registry;.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A util to collect info needed to convert delayed privatizers from MLIR to LLVM.
SmallVector< mlir::Value > mlirVars
SmallVector< omp::PrivateClauseOp > privatizers
MutableArrayRef< BlockArgument > blockArgs
SmallVector< llvm::Value * > llvmVars