201#include "llvm/ADT/SmallVectorExtras.h"
215#include "llvm/ADT/STLExtras.h"
216#include "llvm/ADT/SmallVector.h"
217#include "llvm/ADT/TypeSwitch.h"
218#include "llvm/Support/ErrorHandling.h"
219#include <type_traits>
223#define GEN_PASS_DEF_ACCIMPLICITDATA
224#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
228#define DEBUG_TYPE "acc-implicit-data"
// ACCImplicitData: pass that materializes implicit OpenACC data clauses for
// compute constructs. It derives from the tablegen-generated
// acc::impl::ACCImplicitDataBase (enabled by GEN_PASS_DEF_ACCIMPLICITDATA)
// and inherits that base's constructors.
//
// Member helpers (their definitions appear later in this listing):
//  - getOriginalDataClauseOpForAlias: finds a dominating data clause whose
//    variable must-alias a candidate value.
//  - generateDataClauseOpForCandidate: creates the implicit clause op for one
//    candidate value, honoring the construct's default clause.
//  - generateImplicitDataOps: per-construct driver.
//  - generatePrivateRecipe / generateFirstprivateRecipe: look up or create
//    the privatization recipe symbol for a value.
//  - generateRecipes: attaches recipe symbols to new private/firstprivate
//    operands.
// NOTE(review): several member parameter lists are elided in this listing.
234class ACCImplicitData :
public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
236 using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase;
// Pass entry point; walks the module and processes each compute construct.
238 void runOnOperation()
override;
243 template <
typename OpT>
244 Operation *getOriginalDataClauseOpForAlias(
250 template <
typename OpT>
251 Operation *generateDataClauseOpForCandidate(
252 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
254 const std::optional<acc::ClauseDefaultValue> &defaultClause);
257 template <
typename OpT>
259 generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
260 std::optional<acc::ClauseDefaultValue> &defaultClause,
264 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
269 acc::FirstprivateRecipeOp
270 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
275 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
// Predicate used by generateImplicitDataOps to filter the live-in values of
// a compute-construct region down to those that need an implicit clause.
// A value already produced by an OpenACC data-entry op (ACC_DATA_ENTRY_OPS)
// is special-cased by the check below. That check's result, the remaining
// parameters and the rest of the body are elided in this listing.
282static bool isCandidateForImplicitData(
Value val,
Region &accRegion,
291 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
// Searches `dominatingDataClauses` for an existing data clause whose
// variable must-alias `var` according to AliasAnalysis. Only copyin,
// create, present, no_create and deviceptr clauses are considered.
// generateDataClauseOpForCandidate uses the result to choose the kind of
// the new implicit clause (e.g. no_create, deviceptr or present).
305template <
typename OpT>
306Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
309 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
// Only a must-alias relationship is accepted; may-alias is not enough.
310 for (
auto dataClause : dominatingDataClauses) {
311 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
313 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
314 acc::DevicePtrOp>(dataClauseOp))
315 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust())
// Handles a data clause op whose variable has an acc::MappableType with
// unknown dimensions. It generates explicit bounds with
// MappableType::generateAccBounds and assigns them to the op's bounds.
// The is_same_v check on decltype(dataClauseOp) special-cases deviceptr
// clauses. The guarded statement is elided here; presumably it is an early
// return — confirm.
// NOTE(review): the `.getBoundsMutable()` call below implies that
// dataClauseOp is rebound to a concrete op type in elided lines (e.g. a
// TypeSwitch case). With the plain Operation* parameter the is_same_v test
// could never be true.
323static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
334 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
335 if (mappableTy.hasUnknownDimensions()) {
338 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
342 auto bounds = mappableTy.generateAccBounds(var, builder);
344 dataClauseOp.getBoundsMutable().assign(bounds);
// Returns the acc.private recipe for `var`.
//  - The recipe name comes from OpenACCSupport::getRecipeName.
//  - An existing recipe symbol with that name in `module` is reused.
//  - Otherwise a recipe is created with PrivateRecipeOp::createAndPopulate.
//  - If creation fails, an "implicit private" not-yet-implemented
//    diagnostic is emitted and a null recipe is returned. The comma
//    operator discards the diagnostic handle and yields nullptr.
// Callers must therefore be prepared for a null result.
351ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
355 std::string recipeName =
356 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
359 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
361 return existingRecipe;
368 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
369 if (!recipe.has_value())
370 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
371 return recipe.value();
// Returns the acc.firstprivate recipe for `var`.
//  - The recipe name comes from OpenACCSupport::getRecipeName.
//  - An existing recipe symbol with that name in `module` is reused.
//  - Otherwise a recipe is created with
//    FirstprivateRecipeOp::createAndPopulate.
//  - If creation fails, an "implicit firstprivate" not-yet-implemented
//    diagnostic is emitted and a null recipe is returned via the comma
//    operator.
// Callers must therefore be prepared for a null result.
374acc::FirstprivateRecipeOp
375ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
379 std::string recipeName =
380 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
383 auto existingRecipe =
384 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
386 return existingRecipe;
392 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
394 if (!recipe.has_value())
395 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
396 return recipe.value();
// For each newly created privatization operand in `newOperands`, obtains or
// creates the matching recipe. It then stores a SymbolRefAttr to that
// recipe's symbol as the op's recipe attribute:
//  - private ops (presumably selected via getDefiningOp<acc::PrivateOp>()
//    in the elided condition) get a private recipe;
//  - acc.firstprivate ops get a firstprivate recipe.
// NOTE(review): generatePrivateRecipe and generateFirstprivateRecipe return
// a null recipe on failure. Any null check lives in elided lines; confirm
// one exists before recipe.getSymName() is called.
// NOTE(review): the private branch passes getSymName() directly while the
// firstprivate branch calls .str(). Both build the same SymbolRefAttr, so
// this is only a stylistic inconsistency.
399void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
402 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
403 for (
auto var : newOperands) {
406 auto recipe = generatePrivateRecipe(
409 privateOp.setRecipeAttr(
410 SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
411 }
else if (
auto firstprivateOp = var.
getDefiningOp<acc::FirstprivateOp>()) {
412 auto recipe = generateFirstprivateRecipe(
415 firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
416 module->getContext(), recipe.getSymName().str()));
// Creates the implicit clause or privatization op for one candidate `var`
// of `computeConstructOp`. Visible decision logic:
//  - Type category: taken from acc::MappableType, or for
//    acc::PointerLikeType from its pointee type category. It defaults to
//    uncategorized.
//  - Aliasing: if a dominating data clause must-alias `var` (see
//    getOriginalDataClauseOpForAlias), the new op mirrors its kind:
//    no_create, deviceptr, or otherwise present.
//  - Scalars:
//    * enableImplicitReductionCopy set and the elided reduction-use check
//      on the construct region passes: copyin marked acc_reduction;
//    * kernels / kernel_environment constructs: copyin marked acc_copy;
//    * otherwise (the visible fall-through): firstprivate.
//  - Aggregates: present under default(present), otherwise copyin marked
//    acc_copy.
//  - Unhandled cases are logged under LLVM_DEBUG.
// NOTE(review): several guarding conditions are elided in this listing,
// e.g. the one before the second deviceptr creation.
432template <
typename OpT>
433Operation *ACCImplicitData::generateDataClauseOpForCandidate(
434 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
436 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
437 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
438 acc::VariableTypeCategory typeCategory =
439 acc::VariableTypeCategory::uncategorized;
440 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
441 typeCategory = mappableTy.getTypeCategory(var);
442 }
else if (
auto pointerLikeTy =
443 dyn_cast<acc::PointerLikeType>(var.
getType())) {
444 typeCategory = pointerLikeTy.getPointeeTypeCategory(
446 pointerLikeTy.getElementType());
450 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
451 bool isAnyAggregate = acc::bitEnumContainsAny(
452 typeCategory, acc::VariableTypeCategory::aggregate);
// All implicit ops are created at the compute construct's location.
453 Location loc = computeConstructOp->getLoc();
456 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
457 dominatingDataClauses);
459 if (isa<acc::NoCreateOp>(op))
460 return acc::NoCreateOp::create(builder, loc, var,
465 if (isa<acc::DevicePtrOp>(op))
466 return acc::DevicePtrOp::create(builder, loc, var,
473 return acc::PresentOp::create(builder, loc, var,
481 return acc::DevicePtrOp::create(builder, loc, var,
487 if (enableImplicitReductionCopy &&
489 computeConstructOp->getRegion(0))) {
491 acc::CopyinOp::create(builder, loc, var,
494 copyinOp.setDataClause(acc::DataClause::acc_reduction);
495 return copyinOp.getOperation();
// Kernels-style constructs copy scalars in; the matching copyout is created
// later for copyin clauses. Other constructs firstprivatize scalars.
497 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
498 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
506 acc::CopyinOp::create(builder, loc, var,
509 copyinOp.setDataClause(acc::DataClause::acc_copy);
510 return copyinOp.getOperation();
513 return acc::FirstprivateOp::create(builder, loc, var,
517 }
else if (isAnyAggregate) {
// Under default(present), aggregates are mapped as present; otherwise
// they get copy semantics (copyin marked acc_copy).
521 if (defaultClause.has_value() &&
522 defaultClause.value() == acc::ClauseDefaultValue::Present) {
523 newDataOp = acc::PresentOp::create(builder, loc, var,
530 acc::CopyinOp::create(builder, loc, var,
533 copyinOp.setDataClause(acc::DataClause::acc_copy);
534 newDataOp = copyinOp.getOperation();
543 LLVM_DEBUG(llvm::dbgs()
544 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
// Visits every newly created data-clause operand and then every new private
// operand (joined with llvm::concat).
// The per-operand rewrite of uses inside `accRegion` is elided in this
// listing. Presumably it redirects region uses to the new operands, e.g.
// via replaceAllUsesInRegionWith — confirm.
559static void legalizeValuesInRegion(
Region &accRegion,
562 for (
Value dataClause :
563 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
// Appends each new privatization operand to `accOp` according to its
// defining op:
//  - acc.private results go to the private operand list;
//  - acc.firstprivate results go to the firstprivate operand list.
// An empty `privateOperands` is checked first; the early exit itself is
// elided here. Any other defining op reaches llvm_unreachable; the message
// suggests reductions are the expected case.
570template <
typename OpT>
571static void addNewPrivateOperands(OpT &accOp,
573 if (privateOperands.empty())
576 for (
auto priv : privateOperands) {
577 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
578 accOp.getPrivateOperandsMutable().append(priv);
579 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
580 accOp.getFirstprivateOperandsMutable().append(priv);
582 llvm_unreachable(
"unhandled reduction operand");
// Fragment; the function header is elided. It scans the users of a data
// clause result `res` for a data-exit op (ACC_DATA_EXIT_OPS). Presumably
// this is the body of findDataExitOp, which the exit-generation loop below
// calls — confirm.
589 for (
auto *user : res.getUsers())
590 if (isa<ACC_DATA_EXIT_OPS>(user))
// Fragment of exit-op generation; the function header is elided.
// It walks the sorted data-clause operands in reverse:
//  - For clauses not in `newDataClauseOperands`, it records the clause in
//    `lastDataClause`; presumably the loop then continues.
//  - For each new clause, it creates the matching exit op. Its insertion
//    point is set in elided lines, presumably from the exit op of
//    `lastDataClause` found via findDataExitOp.
// Exit op kinds:
//  - copyin              -> copyout marked acc_copy
//  - present / no_create -> delete
//  - deviceptr           -> handling elided in this listing
//  - anything else       -> llvm_unreachable("unhandled data exit")
608 Value lastDataClause =
nullptr;
609 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
610 if (llvm::find(newDataClauseOperands, dataEntry) ==
611 newDataClauseOperands.end()) {
614 lastDataClause = dataEntry;
618 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
620 Operation *dataEntryOp = dataEntry.getDefiningOp();
621 if (isa<acc::CopyinOp>(dataEntryOp)) {
622 auto copyoutOp = acc::CopyoutOp::create(
626 copyoutOp.setDataClause(acc::DataClause::acc_copy);
627 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
628 auto deleteOp = acc::DeleteOp::create(
629 builder, dataEntryOp->
getLoc(), dataEntry,
633 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
636 llvm_unreachable(
"unhandled data exit");
638 lastDataClause = dataEntry;
// Fragment; the function header is elided. Presumably this is
// getBaseRefsChain, which the sorted-insertion code below uses.
// It builds the chain of base references for a value:
//  - `val` is recorded;
//  - if `val` changed in elided lines and differs from the current front,
//    it is prepended;
//  - each ViewLikeOpInterface view source is then followed and prepended,
//    so the outermost base ends up first.
649 baseRefs.push_back(val);
654 if (val != baseRefs.front())
655 baseRefs.insert(baseRefs.begin(), val);
659 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
660 val = viewLikeOp.getViewSource();
661 baseRefs.insert(baseRefs.begin(), val);
// Fragment of inserting a new clause into `sortedDataClauseOperands`; the
// header is elided.
// It finds the first existing clause whose variable's base-reference chain
// contains the new clause's variable, i.e. the new variable is a base of
// that existing one. The new op is then moved before that clause and its
// accVar inserted at that position, so base variables precede views derived
// from them. The not-found path is elided in this listing.
675 std::find_if(sortedDataClauseOperands.begin(),
676 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
679 auto var = acc::getVar(dataClauseVal.getDefiningOp());
680 auto baseRefs = getBaseRefsChain(var);
686 return std::find(baseRefs.begin(), baseRefs.end(),
687 acc::getVar(newClause)) != baseRefs.end();
690 if (insertPos != sortedDataClauseOperands.end()) {
691 newClause->
moveBefore(insertPos->getDefiningOp());
692 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
// Per-construct driver for implicit data generation. Visible steps:
//  1. With default(none), no implicit clauses are generated; the early exit
//     is elided here. Otherwise the default must be absent or present
//     (assert).
//  2. Live-in values of the construct's single region are filtered with
//     isCandidateForImplicitData.
//  3. For each candidate:
//     * a clause op is generated (generateDataClauseOpForCandidate);
//     * unknown bounds are filled in;
//     * the result is routed by op kind to the private or the data-clause
//       operand list (the push_backs are elided).
//  4. Uses inside the region are legalized against the new operands.
//  5. For non-kernels constructs, privatization recipes are generated.
//  6. New data clauses are inserted into the sorted operand list, exit ops
//     are generated, and the construct's operand lists are updated.
// NOTE(review): candidateVars is tested for emptiness twice (source lines
// 721 and 728). If the elided line after the first test returns and
// candidateVars is not modified in between, the second test is redundant.
// Confirm against the full source.
698template <
typename OpT>
699void ACCImplicitData::generateImplicitDataOps(
700 ModuleOp &module, OpT computeConstructOp,
701 std::optional<acc::ClauseDefaultValue> &defaultClause,
705 if (defaultClause.has_value() &&
706 defaultClause.value() == acc::ClauseDefaultValue::None)
708 assert(!defaultClause.has_value() ||
709 defaultClause.value() == acc::ClauseDefaultValue::Present);
// Compute constructs handled here have exactly one region (asserted by the
// caller).
712 Region &accRegion = computeConstructOp->getRegion(0);
717 auto isCandidate{[&](
Value val) ->
bool {
718 return isCandidateForImplicitData(val, accRegion, accSupport);
720 auto candidateVars(llvm::filter_to_vector(liveInValues, isCandidate));
721 if (candidateVars.empty())
728 if (!candidateVars.empty()) {
729 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
730 << computeConstructOp <<
"\n");
// Dominance and post-dominance are used to collect existing data clauses
// that dominate the construct.
732 auto &domInfo = this->getAnalysis<DominanceInfo>();
733 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
734 auto dominatingDataClauses =
736 for (
auto var : candidateVars) {
737 auto newDataClauseOp = generateDataClauseOpForCandidate(
738 var, module, builder, computeConstructOp, dominatingDataClauses,
740 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
741 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
742 <<
"\t" << *newDataClauseOp <<
"\n");
743 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
746 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
754 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
// Recipes and private operands are attached only for non-kernels
// constructs.
758 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
759 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
760 generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
764 computeConstructOp.getDataClauseOperands());
765 for (
auto newClause : newDataClauseOperands)
766 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
769 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
770 sortedDataClauseOperands);
772 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
773 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
774 addNewPrivateOperands(computeConstructOp, newPrivateOperands);
775 computeConstructOp.getDataClauseOperandsMutable().assign(
776 sortedDataClauseOperands);
// Pass entry point. Walks every op in the module. Each compute construct
// (ACC_COMPUTE_CONSTRUCT_OPS) or acc.kernel_environment op must have
// exactly one region, and generateImplicitDataOps is called for it.
// The default clause passed along is computed in elided lines (presumably
// via acc::getDefaultAttr). Dispatch of `op` to the typed OpT overload is
// also elided.
779void ACCImplicitData::runOnOperation() {
780 ModuleOp module = this->getOperation();
784 module.walk([&](Operation *op) {
785 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
786 assert(op->getNumRegions() == 1 && "must have 1 region");
792 generateImplicitDataOps(module, op, defaultClause, accSupport);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region &region)
Check if a value use is legal in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static constexpr StringLiteral getFromDefaultClauseAttrName()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
llvm::SmallVector< mlir::Value > getDominatingDataClauses(mlir::Operation *computeConstructOp, mlir::DominanceInfo &domInfo, mlir::PostDominanceInfo &postDomInfo)
Collects all data clauses that dominate the compute construct.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
bool isDeviceValue(mlir::Value val)
Check if a value represents device data.
mlir::Value getBaseEntity(mlir::Value val)
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch