214#include "llvm/ADT/STLExtras.h"
215#include "llvm/ADT/SmallVector.h"
216#include "llvm/ADT/TypeSwitch.h"
217#include "llvm/Support/ErrorHandling.h"
218#include <type_traits>
222#define GEN_PASS_DEF_ACCIMPLICITDATA
223#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
227#define DEBUG_TYPE "acc-implicit-data"
237 void runOnOperation()
override;
247 template <
typename OpT>
248 Operation *getOriginalDataClauseOpForAlias(
254 template <
typename OpT>
255 Operation *generateDataClauseOpForCandidate(
256 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
258 const std::optional<acc::ClauseDefaultValue> &defaultClause);
261 template <
typename OpT>
262 void generateImplicitDataOps(
263 ModuleOp &module, OpT computeConstructOp,
264 std::optional<acc::ClauseDefaultValue> &defaultClause);
267 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
272 acc::FirstprivateRecipeOp
273 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
278 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
286static bool isCandidateForImplicitData(
Value val,
Region &accRegion) {
294 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
305ACCImplicitData::getDominatingDataClauses(
Operation *computeConstructOp) {
306 llvm::SmallSetVector<Value, 8> dominatingDataClauses;
309 .Case<acc::ParallelOp, acc::KernelsOp, acc::SerialOp>([&](
auto op) {
310 for (
auto dataClause : op.getDataClauseOperands()) {
311 dominatingDataClauses.insert(dataClause);
318 while (currParentOp) {
319 if (isa<acc::DataOp>(currParentOp)) {
320 for (
auto dataClause :
321 dyn_cast<acc::DataOp>(currParentOp).getDataClauseOperands()) {
322 dominatingDataClauses.insert(dataClause);
325 currParentOp = currParentOp->getParentOp();
329 auto funcOp = computeConstructOp->
getParentOfType<FunctionOpInterface>();
331 return dominatingDataClauses.takeVector();
336 auto &domInfo = this->getAnalysis<DominanceInfo>();
337 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
338 funcOp->walk([&](acc::DeclareEnterOp declareEnterOp) {
339 if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
342 for (
auto *user : declareEnterOp.getToken().getUsers())
343 if (
auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
344 exits.push_back(declareExit);
348 if (!exits.empty() && llvm::all_of(exits, [&](acc::DeclareExitOp exitOp) {
349 return postDomInfo.postDominates(exitOp, computeConstructOp);
351 for (
auto dataClause : declareEnterOp.getDataClauseOperands())
352 dominatingDataClauses.insert(dataClause);
357 return dominatingDataClauses.takeVector();
360template <
typename OpT>
361Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
364 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
365 for (
auto dataClause : dominatingDataClauses) {
366 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
368 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
369 acc::DevicePtrOp>(dataClauseOp))
370 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust())
378static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
389 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
390 if (mappableTy.hasUnknownDimensions()) {
393 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
397 auto bounds = mappableTy.generateAccBounds(var, builder);
399 dataClauseOp.getBoundsMutable().assign(bounds);
406ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
410 std::string recipeName =
411 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
414 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
416 return existingRecipe;
423 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
424 if (!recipe.has_value())
425 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
426 return recipe.value();
429acc::FirstprivateRecipeOp
430ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
434 std::string recipeName =
435 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
438 auto existingRecipe =
439 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
441 return existingRecipe;
447 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
449 if (!recipe.has_value())
450 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
451 return recipe.value();
454void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
458 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
459 for (
auto var : newOperands) {
462 auto recipe = generatePrivateRecipe(
465 newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(),
466 recipe.getSymName().str()));
468 auto recipe = generateFirstprivateRecipe(
471 newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(),
472 recipe.getSymName().str()));
488template <
typename OpT>
489Operation *ACCImplicitData::generateDataClauseOpForCandidate(
490 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
492 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
493 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
494 acc::VariableTypeCategory typeCategory =
495 acc::VariableTypeCategory::uncategorized;
496 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
497 typeCategory = mappableTy.getTypeCategory(var);
498 }
else if (
auto pointerLikeTy =
499 dyn_cast<acc::PointerLikeType>(var.
getType())) {
500 typeCategory = pointerLikeTy.getPointeeTypeCategory(
502 pointerLikeTy.getElementType());
506 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
507 bool isAnyAggregate = acc::bitEnumContainsAny(
508 typeCategory, acc::VariableTypeCategory::aggregate);
512 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
513 dominatingDataClauses);
515 if (isa<acc::NoCreateOp>(op))
516 return acc::NoCreateOp::create(builder, loc, var,
521 if (isa<acc::DevicePtrOp>(op))
522 return acc::DevicePtrOp::create(builder, loc, var,
529 return acc::PresentOp::create(builder, loc, var,
533 }
else if (isScalar) {
534 if (enableImplicitReductionCopy &&
538 acc::CopyinOp::create(builder, loc, var,
541 copyinOp.setDataClause(acc::DataClause::acc_reduction);
542 return copyinOp.getOperation();
544 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
545 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
553 acc::CopyinOp::create(builder, loc, var,
556 copyinOp.setDataClause(acc::DataClause::acc_copy);
557 return copyinOp.getOperation();
560 return acc::FirstprivateOp::create(builder, loc, var,
564 }
else if (isAnyAggregate) {
568 if (defaultClause.has_value() &&
569 defaultClause.value() == acc::ClauseDefaultValue::Present) {
570 newDataOp = acc::PresentOp::create(builder, loc, var,
575 acc::CopyinOp::create(builder, loc, var,
578 copyinOp.setDataClause(acc::DataClause::acc_copy);
579 newDataOp = copyinOp.getOperation();
588 LLVM_DEBUG(llvm::dbgs()
589 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
604static void legalizeValuesInRegion(
Region &accRegion,
607 for (
Value dataClause :
608 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
617template <
typename OpT>
621 assert(privateOperands.size() == privateRecipeSyms.size());
622 if (privateOperands.empty())
632 if (accOp.getPrivatizationRecipes().has_value())
633 for (
auto privatization : accOp.getPrivatizationRecipesAttr())
634 completePrivateRecipesSyms.push_back(privatization);
635 if (accOp.getFirstprivatizationRecipes().has_value())
636 for (
auto privatization : accOp.getFirstprivatizationRecipesAttr())
637 completeFirstprivateRecipesSyms.push_back(privatization);
640 for (
auto [priv, privateRecipeSym] :
641 llvm::zip(privateOperands, privateRecipeSyms)) {
642 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
643 newPrivateOperands.push_back(priv);
644 completePrivateRecipesSyms.push_back(privateRecipeSym);
645 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
646 newFirstprivateOperands.push_back(priv);
647 completeFirstprivateRecipesSyms.push_back(privateRecipeSym);
649 llvm_unreachable(
"unhandled private operand");
654 accOp.getPrivateOperandsMutable().append(newPrivateOperands);
655 accOp.getFirstprivateOperandsMutable().append(newFirstprivateOperands);
658 if (!completePrivateRecipesSyms.empty())
659 accOp.setPrivatizationRecipesAttr(
660 ArrayAttr::get(accOp.getContext(), completePrivateRecipesSyms));
661 if (!completeFirstprivateRecipesSyms.empty())
662 accOp.setFirstprivatizationRecipesAttr(
663 ArrayAttr::get(accOp.getContext(), completeFirstprivateRecipesSyms));
668 for (
auto *user : res.getUsers())
669 if (isa<ACC_DATA_EXIT_OPS>(user))
687 Value lastDataClause =
nullptr;
688 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
689 if (llvm::find(newDataClauseOperands, dataEntry) ==
690 newDataClauseOperands.end()) {
693 lastDataClause = dataEntry;
697 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
699 Operation *dataEntryOp = dataEntry.getDefiningOp();
700 if (isa<acc::CopyinOp>(dataEntryOp)) {
701 auto copyoutOp = acc::CopyoutOp::create(
705 copyoutOp.setDataClause(acc::DataClause::acc_copy);
706 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
707 auto deleteOp = acc::DeleteOp::create(
708 builder, dataEntryOp->
getLoc(), dataEntry,
712 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
715 llvm_unreachable(
"unhandled data exit");
717 lastDataClause = dataEntry;
728 baseRefs.push_back(val);
733 if (val != baseRefs.front())
734 baseRefs.insert(baseRefs.begin(), val);
738 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
739 val = viewLikeOp.getViewSource();
740 baseRefs.insert(baseRefs.begin(), val);
754 std::find_if(sortedDataClauseOperands.begin(),
755 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
758 auto var = acc::getVar(dataClauseVal.getDefiningOp());
759 auto baseRefs = getBaseRefsChain(var);
765 return std::find(baseRefs.begin(), baseRefs.end(),
766 acc::getVar(newClause)) != baseRefs.end();
769 if (insertPos != sortedDataClauseOperands.end()) {
770 newClause->
moveBefore(insertPos->getDefiningOp());
771 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
777template <
typename OpT>
778void ACCImplicitData::generateImplicitDataOps(
779 ModuleOp &module, OpT computeConstructOp,
780 std::optional<acc::ClauseDefaultValue> &defaultClause) {
783 if (defaultClause.has_value() &&
784 defaultClause.value() == acc::ClauseDefaultValue::None)
786 assert(!defaultClause.has_value() ||
787 defaultClause.value() == acc::ClauseDefaultValue::Present);
795 auto isCandidate{[&](
Value val) ->
bool {
796 return isCandidateForImplicitData(val, accRegion);
799 llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
800 if (candidateVars.empty())
807 if (!candidateVars.empty()) {
808 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
809 << computeConstructOp <<
"\n");
811 auto dominatingDataClauses = getDominatingDataClauses(computeConstructOp);
812 for (
auto var : candidateVars) {
813 auto newDataClauseOp = generateDataClauseOpForCandidate(
814 var, module, builder, computeConstructOp, dominatingDataClauses,
816 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
817 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
818 <<
"\t" << *newDataClauseOp <<
"\n");
819 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
822 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
830 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
835 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
836 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
837 generateRecipes(module, builder, computeConstructOp, newPrivateOperands,
838 newPrivateRecipeSyms);
842 computeConstructOp.getDataClauseOperands());
843 for (
auto newClause : newDataClauseOperands)
844 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
847 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
848 sortedDataClauseOperands);
851 assert(newPrivateOperands.size() == newPrivateRecipeSyms.size() &&
853 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
854 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
855 addNewPrivateOperands(computeConstructOp, newPrivateOperands,
856 newPrivateRecipeSyms);
858 computeConstructOp.getDataClauseOperandsMutable().assign(
859 sortedDataClauseOperands);
862void ACCImplicitData::runOnOperation() {
863 ModuleOp module = this->getOperation();
864 module.walk([&](Operation *op) {
865 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
866 assert(op->getNumRegions() == 1 && "must have 1 region");
872 generateImplicitDataOps(module, op, defaultClause);
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getBaseEntity(mlir::Value val)
bool isOnlyUsedByPrivateClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.private operations in the region.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch