214#include "llvm/ADT/STLExtras.h"
215#include "llvm/ADT/SmallVector.h"
216#include "llvm/ADT/TypeSwitch.h"
217#include "llvm/Support/ErrorHandling.h"
218#include <type_traits>
222#define GEN_PASS_DEF_ACCIMPLICITDATA
223#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
227#define DEBUG_TYPE "acc-implicit-data"
233class ACCImplicitData :
public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
235 using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase;
237 void runOnOperation()
override;
242 template <
typename OpT>
243 Operation *getOriginalDataClauseOpForAlias(
249 template <
typename OpT>
250 Operation *generateDataClauseOpForCandidate(
251 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
253 const std::optional<acc::ClauseDefaultValue> &defaultClause);
256 template <
typename OpT>
258 generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
259 std::optional<acc::ClauseDefaultValue> &defaultClause,
263 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
268 acc::FirstprivateRecipeOp
269 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
274 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
281static bool isCandidateForImplicitData(
Value val,
Region &accRegion,
293 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
303template <
typename OpT>
304Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
307 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
308 for (
auto dataClause : dominatingDataClauses) {
309 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
311 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
312 acc::DevicePtrOp>(dataClauseOp))
313 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust())
321static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
332 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
333 if (mappableTy.hasUnknownDimensions()) {
336 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
340 auto bounds = mappableTy.generateAccBounds(var, builder);
342 dataClauseOp.getBoundsMutable().assign(bounds);
349ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
353 std::string recipeName =
354 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
357 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
359 return existingRecipe;
366 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
367 if (!recipe.has_value())
368 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
369 return recipe.value();
372acc::FirstprivateRecipeOp
373ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
377 std::string recipeName =
378 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
381 auto existingRecipe =
382 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
384 return existingRecipe;
390 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
392 if (!recipe.has_value())
393 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
394 return recipe.value();
397void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
400 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
401 for (
auto var : newOperands) {
404 auto recipe = generatePrivateRecipe(
407 privateOp.setRecipeAttr(
408 SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
409 }
else if (
auto firstprivateOp = var.
getDefiningOp<acc::FirstprivateOp>()) {
410 auto recipe = generateFirstprivateRecipe(
413 firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
414 module->getContext(), recipe.getSymName().str()));
430template <
typename OpT>
431Operation *ACCImplicitData::generateDataClauseOpForCandidate(
432 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
434 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
435 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
436 acc::VariableTypeCategory typeCategory =
437 acc::VariableTypeCategory::uncategorized;
438 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
439 typeCategory = mappableTy.getTypeCategory(var);
440 }
else if (
auto pointerLikeTy =
441 dyn_cast<acc::PointerLikeType>(var.
getType())) {
442 typeCategory = pointerLikeTy.getPointeeTypeCategory(
444 pointerLikeTy.getElementType());
448 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
449 bool isAnyAggregate = acc::bitEnumContainsAny(
450 typeCategory, acc::VariableTypeCategory::aggregate);
451 Location loc = computeConstructOp->getLoc();
454 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
455 dominatingDataClauses);
457 if (isa<acc::NoCreateOp>(op))
458 return acc::NoCreateOp::create(builder, loc, var,
463 if (isa<acc::DevicePtrOp>(op))
464 return acc::DevicePtrOp::create(builder, loc, var,
471 return acc::PresentOp::create(builder, loc, var,
475 }
else if (isScalar) {
476 if (enableImplicitReductionCopy &&
478 computeConstructOp->getRegion(0))) {
480 acc::CopyinOp::create(builder, loc, var,
483 copyinOp.setDataClause(acc::DataClause::acc_reduction);
484 return copyinOp.getOperation();
486 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
487 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
495 acc::CopyinOp::create(builder, loc, var,
498 copyinOp.setDataClause(acc::DataClause::acc_copy);
499 return copyinOp.getOperation();
502 return acc::FirstprivateOp::create(builder, loc, var,
506 }
else if (isAnyAggregate) {
510 if (defaultClause.has_value() &&
511 defaultClause.value() == acc::ClauseDefaultValue::Present) {
512 newDataOp = acc::PresentOp::create(builder, loc, var,
519 acc::CopyinOp::create(builder, loc, var,
522 copyinOp.setDataClause(acc::DataClause::acc_copy);
523 newDataOp = copyinOp.getOperation();
532 LLVM_DEBUG(llvm::dbgs()
533 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
548static void legalizeValuesInRegion(
Region &accRegion,
551 for (
Value dataClause :
552 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
559template <
typename OpT>
560static void addNewPrivateOperands(OpT &accOp,
562 if (privateOperands.empty())
565 for (
auto priv : privateOperands) {
566 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
567 accOp.getPrivateOperandsMutable().append(priv);
568 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
569 accOp.getFirstprivateOperandsMutable().append(priv);
571 llvm_unreachable(
"unhandled reduction operand");
578 for (
auto *user : res.getUsers())
579 if (isa<ACC_DATA_EXIT_OPS>(user))
597 Value lastDataClause =
nullptr;
598 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
599 if (llvm::find(newDataClauseOperands, dataEntry) ==
600 newDataClauseOperands.end()) {
603 lastDataClause = dataEntry;
607 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
609 Operation *dataEntryOp = dataEntry.getDefiningOp();
610 if (isa<acc::CopyinOp>(dataEntryOp)) {
611 auto copyoutOp = acc::CopyoutOp::create(
615 copyoutOp.setDataClause(acc::DataClause::acc_copy);
616 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
617 auto deleteOp = acc::DeleteOp::create(
618 builder, dataEntryOp->
getLoc(), dataEntry,
622 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
625 llvm_unreachable(
"unhandled data exit");
627 lastDataClause = dataEntry;
638 baseRefs.push_back(val);
643 if (val != baseRefs.front())
644 baseRefs.insert(baseRefs.begin(), val);
648 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
649 val = viewLikeOp.getViewSource();
650 baseRefs.insert(baseRefs.begin(), val);
664 std::find_if(sortedDataClauseOperands.begin(),
665 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
668 auto var = acc::getVar(dataClauseVal.getDefiningOp());
669 auto baseRefs = getBaseRefsChain(var);
675 return std::find(baseRefs.begin(), baseRefs.end(),
676 acc::getVar(newClause)) != baseRefs.end();
679 if (insertPos != sortedDataClauseOperands.end()) {
680 newClause->
moveBefore(insertPos->getDefiningOp());
681 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
687template <
typename OpT>
688void ACCImplicitData::generateImplicitDataOps(
689 ModuleOp &module, OpT computeConstructOp,
690 std::optional<acc::ClauseDefaultValue> &defaultClause,
694 if (defaultClause.has_value() &&
695 defaultClause.value() == acc::ClauseDefaultValue::None)
697 assert(!defaultClause.has_value() ||
698 defaultClause.value() == acc::ClauseDefaultValue::Present);
701 Region &accRegion = computeConstructOp->getRegion(0);
706 auto isCandidate{[&](
Value val) ->
bool {
707 return isCandidateForImplicitData(val, accRegion, accSupport);
710 llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
711 if (candidateVars.empty())
718 if (!candidateVars.empty()) {
719 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
720 << computeConstructOp <<
"\n");
722 auto &domInfo = this->getAnalysis<DominanceInfo>();
723 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
724 auto dominatingDataClauses =
726 for (
auto var : candidateVars) {
727 auto newDataClauseOp = generateDataClauseOpForCandidate(
728 var, module, builder, computeConstructOp, dominatingDataClauses,
730 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
731 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
732 <<
"\t" << *newDataClauseOp <<
"\n");
733 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
736 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
744 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
748 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
749 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
750 generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
754 computeConstructOp.getDataClauseOperands());
755 for (
auto newClause : newDataClauseOperands)
756 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
759 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
760 sortedDataClauseOperands);
762 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
763 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
764 addNewPrivateOperands(computeConstructOp, newPrivateOperands);
765 computeConstructOp.getDataClauseOperandsMutable().assign(
766 sortedDataClauseOperands);
769void ACCImplicitData::runOnOperation() {
770 ModuleOp module = this->getOperation();
774 module.walk([&](Operation *op) {
775 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
776 assert(op->getNumRegions() == 1 && "must have 1 region");
782 generateImplicitDataOps(module, op, defaultClause, accSupport);
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region &region)
Check if a value use is legal in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
static constexpr StringLiteral getFromDefaultClauseAttrName()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
llvm::SmallVector< mlir::Value > getDominatingDataClauses(mlir::Operation *computeConstructOp, mlir::DominanceInfo &domInfo, mlir::PostDominanceInfo &postDomInfo)
Collects all data clauses that dominate the compute construct.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getBaseEntity(mlir::Value val)
bool isOnlyUsedByPrivateClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.private operations in the region.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch