214#include "llvm/ADT/STLExtras.h"
215#include "llvm/ADT/SmallVector.h"
216#include "llvm/ADT/TypeSwitch.h"
217#include "llvm/Support/ErrorHandling.h"
218#include <type_traits>
222#define GEN_PASS_DEF_ACCIMPLICITDATA
223#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
227#define DEBUG_TYPE "acc-implicit-data"
237 void runOnOperation()
override;
242 template <
typename OpT>
243 Operation *getOriginalDataClauseOpForAlias(
249 template <
typename OpT>
250 Operation *generateDataClauseOpForCandidate(
251 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
253 const std::optional<acc::ClauseDefaultValue> &defaultClause);
256 template <
typename OpT>
258 generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
259 std::optional<acc::ClauseDefaultValue> &defaultClause,
263 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
268 acc::FirstprivateRecipeOp
269 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
274 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
281static bool isCandidateForImplicitData(
Value val,
Region &accRegion,
290 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
304template <
typename OpT>
305Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
308 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
309 for (
auto dataClause : dominatingDataClauses) {
310 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
312 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
313 acc::DevicePtrOp>(dataClauseOp))
314 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust())
322static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
333 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
334 if (mappableTy.hasUnknownDimensions()) {
337 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
341 auto bounds = mappableTy.generateAccBounds(var, builder);
343 dataClauseOp.getBoundsMutable().assign(bounds);
350ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
354 std::string recipeName =
355 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
358 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
360 return existingRecipe;
367 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
368 if (!recipe.has_value())
369 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
370 return recipe.value();
373acc::FirstprivateRecipeOp
374ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
378 std::string recipeName =
379 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
382 auto existingRecipe =
383 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
385 return existingRecipe;
391 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
393 if (!recipe.has_value())
394 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
395 return recipe.value();
398void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
401 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
402 for (
auto var : newOperands) {
405 auto recipe = generatePrivateRecipe(
408 privateOp.setRecipeAttr(
409 SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
410 }
else if (
auto firstprivateOp = var.
getDefiningOp<acc::FirstprivateOp>()) {
411 auto recipe = generateFirstprivateRecipe(
414 firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
415 module->getContext(), recipe.getSymName().str()));
431template <
typename OpT>
432Operation *ACCImplicitData::generateDataClauseOpForCandidate(
433 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
435 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
436 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
437 acc::VariableTypeCategory typeCategory =
438 acc::VariableTypeCategory::uncategorized;
439 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
440 typeCategory = mappableTy.getTypeCategory(var);
441 }
else if (
auto pointerLikeTy =
442 dyn_cast<acc::PointerLikeType>(var.
getType())) {
443 typeCategory = pointerLikeTy.getPointeeTypeCategory(
445 pointerLikeTy.getElementType());
449 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
450 bool isAnyAggregate = acc::bitEnumContainsAny(
451 typeCategory, acc::VariableTypeCategory::aggregate);
452 Location loc = computeConstructOp->getLoc();
455 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
456 dominatingDataClauses);
458 if (isa<acc::NoCreateOp>(op))
459 return acc::NoCreateOp::create(builder, loc, var,
464 if (isa<acc::DevicePtrOp>(op))
465 return acc::DevicePtrOp::create(builder, loc, var,
472 return acc::PresentOp::create(builder, loc, var,
480 return acc::DevicePtrOp::create(builder, loc, var,
486 if (enableImplicitReductionCopy &&
488 computeConstructOp->getRegion(0))) {
490 acc::CopyinOp::create(builder, loc, var,
493 copyinOp.setDataClause(acc::DataClause::acc_reduction);
494 return copyinOp.getOperation();
496 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
497 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
505 acc::CopyinOp::create(builder, loc, var,
508 copyinOp.setDataClause(acc::DataClause::acc_copy);
509 return copyinOp.getOperation();
512 return acc::FirstprivateOp::create(builder, loc, var,
516 }
else if (isAnyAggregate) {
520 if (defaultClause.has_value() &&
521 defaultClause.value() == acc::ClauseDefaultValue::Present) {
522 newDataOp = acc::PresentOp::create(builder, loc, var,
529 acc::CopyinOp::create(builder, loc, var,
532 copyinOp.setDataClause(acc::DataClause::acc_copy);
533 newDataOp = copyinOp.getOperation();
542 LLVM_DEBUG(llvm::dbgs()
543 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
558static void legalizeValuesInRegion(
Region &accRegion,
561 for (
Value dataClause :
562 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
569template <
typename OpT>
570static void addNewPrivateOperands(OpT &accOp,
572 if (privateOperands.empty())
575 for (
auto priv : privateOperands) {
576 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
577 accOp.getPrivateOperandsMutable().append(priv);
578 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
579 accOp.getFirstprivateOperandsMutable().append(priv);
581 llvm_unreachable(
"unhandled reduction operand");
588 for (
auto *user : res.getUsers())
589 if (isa<ACC_DATA_EXIT_OPS>(user))
607 Value lastDataClause =
nullptr;
608 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
609 if (llvm::find(newDataClauseOperands, dataEntry) ==
610 newDataClauseOperands.end()) {
613 lastDataClause = dataEntry;
617 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
619 Operation *dataEntryOp = dataEntry.getDefiningOp();
620 if (isa<acc::CopyinOp>(dataEntryOp)) {
621 auto copyoutOp = acc::CopyoutOp::create(
625 copyoutOp.setDataClause(acc::DataClause::acc_copy);
626 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
627 auto deleteOp = acc::DeleteOp::create(
628 builder, dataEntryOp->
getLoc(), dataEntry,
632 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
635 llvm_unreachable(
"unhandled data exit");
637 lastDataClause = dataEntry;
648 baseRefs.push_back(val);
653 if (val != baseRefs.front())
654 baseRefs.insert(baseRefs.begin(), val);
658 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
659 val = viewLikeOp.getViewSource();
660 baseRefs.insert(baseRefs.begin(), val);
674 std::find_if(sortedDataClauseOperands.begin(),
675 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
678 auto var = acc::getVar(dataClauseVal.getDefiningOp());
679 auto baseRefs = getBaseRefsChain(var);
685 return std::find(baseRefs.begin(), baseRefs.end(),
686 acc::getVar(newClause)) != baseRefs.end();
689 if (insertPos != sortedDataClauseOperands.end()) {
690 newClause->
moveBefore(insertPos->getDefiningOp());
691 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
697template <
typename OpT>
698void ACCImplicitData::generateImplicitDataOps(
699 ModuleOp &module, OpT computeConstructOp,
700 std::optional<acc::ClauseDefaultValue> &defaultClause,
704 if (defaultClause.has_value() &&
705 defaultClause.value() == acc::ClauseDefaultValue::None)
707 assert(!defaultClause.has_value() ||
708 defaultClause.value() == acc::ClauseDefaultValue::Present);
711 Region &accRegion = computeConstructOp->getRegion(0);
716 auto isCandidate{[&](
Value val) ->
bool {
717 return isCandidateForImplicitData(val, accRegion, accSupport);
720 llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
721 if (candidateVars.empty())
728 if (!candidateVars.empty()) {
729 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
730 << computeConstructOp <<
"\n");
732 auto &domInfo = this->getAnalysis<DominanceInfo>();
733 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
734 auto dominatingDataClauses =
736 for (
auto var : candidateVars) {
737 auto newDataClauseOp = generateDataClauseOpForCandidate(
738 var, module, builder, computeConstructOp, dominatingDataClauses,
740 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
741 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
742 <<
"\t" << *newDataClauseOp <<
"\n");
743 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
746 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
754 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
758 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
759 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
760 generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
764 computeConstructOp.getDataClauseOperands());
765 for (
auto newClause : newDataClauseOperands)
766 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
769 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
770 sortedDataClauseOperands);
772 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
773 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
774 addNewPrivateOperands(computeConstructOp, newPrivateOperands);
775 computeConstructOp.getDataClauseOperandsMutable().assign(
776 sortedDataClauseOperands);
779void ACCImplicitData::runOnOperation() {
780 ModuleOp module = this->getOperation();
784 module.walk([&](Operation *op) {
785 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
786 assert(op->getNumRegions() == 1 && "must have 1 region");
792 generateImplicitDataOps(module, op, defaultClause, accSupport);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region ®ion)
Check if a value use is legal in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static constexpr StringLiteral getFromDefaultClauseAttrName()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region ®ion)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
llvm::SmallVector< mlir::Value > getDominatingDataClauses(mlir::Operation *computeConstructOp, mlir::DominanceInfo &domInfo, mlir::PostDominanceInfo &postDomInfo)
Collects all data clauses that dominate the compute construct.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
bool isDeviceValue(mlir::Value val)
Check if a value represents device data.
mlir::Value getBaseEntity(mlir::Value val)
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch