214#include "llvm/ADT/STLExtras.h"
215#include "llvm/ADT/SmallVector.h"
216#include "llvm/ADT/TypeSwitch.h"
217#include "llvm/Support/ErrorHandling.h"
218#include <type_traits>
222#define GEN_PASS_DEF_ACCIMPLICITDATA
223#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
227#define DEBUG_TYPE "acc-implicit-data"
237 void runOnOperation()
override;
242 template <
typename OpT>
243 Operation *getOriginalDataClauseOpForAlias(
249 template <
typename OpT>
250 Operation *generateDataClauseOpForCandidate(
251 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
253 const std::optional<acc::ClauseDefaultValue> &defaultClause);
256 template <
typename OpT>
258 generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
259 std::optional<acc::ClauseDefaultValue> &defaultClause,
263 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
268 acc::FirstprivateRecipeOp
269 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
274 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
281static bool isCandidateForImplicitData(
Value val,
Region &accRegion,
293 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
303template <
typename OpT>
304Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
307 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
308 for (
auto dataClause : dominatingDataClauses) {
309 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
311 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
312 acc::DevicePtrOp>(dataClauseOp))
313 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust())
321static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
332 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
333 if (mappableTy.hasUnknownDimensions()) {
336 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
340 auto bounds = mappableTy.generateAccBounds(var, builder);
342 dataClauseOp.getBoundsMutable().assign(bounds);
349ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
353 std::string recipeName =
354 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
357 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
359 return existingRecipe;
366 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
367 if (!recipe.has_value())
368 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
369 return recipe.value();
372acc::FirstprivateRecipeOp
373ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
377 std::string recipeName =
378 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
381 auto existingRecipe =
382 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
384 return existingRecipe;
390 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
392 if (!recipe.has_value())
393 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
394 return recipe.value();
397void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
400 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
401 for (
auto var : newOperands) {
403 if (
auto privateOp = dyn_cast<acc::PrivateOp>(var.
getDefiningOp())) {
404 auto recipe = generatePrivateRecipe(
407 privateOp.setRecipeAttr(
408 SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
409 }
else if (
auto firstprivateOp =
411 auto recipe = generateFirstprivateRecipe(
414 firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
415 module->getContext(), recipe.getSymName().str()));
431template <
typename OpT>
432Operation *ACCImplicitData::generateDataClauseOpForCandidate(
433 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
435 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
436 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
437 acc::VariableTypeCategory typeCategory =
438 acc::VariableTypeCategory::uncategorized;
439 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
440 typeCategory = mappableTy.getTypeCategory(var);
441 }
else if (
auto pointerLikeTy =
442 dyn_cast<acc::PointerLikeType>(var.
getType())) {
443 typeCategory = pointerLikeTy.getPointeeTypeCategory(
445 pointerLikeTy.getElementType());
449 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
450 bool isAnyAggregate = acc::bitEnumContainsAny(
451 typeCategory, acc::VariableTypeCategory::aggregate);
452 Location loc = computeConstructOp->getLoc();
455 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
456 dominatingDataClauses);
458 if (isa<acc::NoCreateOp>(op))
459 return acc::NoCreateOp::create(builder, loc, var,
464 if (isa<acc::DevicePtrOp>(op))
465 return acc::DevicePtrOp::create(builder, loc, var,
472 return acc::PresentOp::create(builder, loc, var,
476 }
else if (isScalar) {
477 if (enableImplicitReductionCopy &&
479 computeConstructOp->getRegion(0))) {
481 acc::CopyinOp::create(builder, loc, var,
484 copyinOp.setDataClause(acc::DataClause::acc_reduction);
485 return copyinOp.getOperation();
487 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
488 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
496 acc::CopyinOp::create(builder, loc, var,
499 copyinOp.setDataClause(acc::DataClause::acc_copy);
500 return copyinOp.getOperation();
503 return acc::FirstprivateOp::create(builder, loc, var,
507 }
else if (isAnyAggregate) {
511 if (defaultClause.has_value() &&
512 defaultClause.value() == acc::ClauseDefaultValue::Present) {
513 newDataOp = acc::PresentOp::create(builder, loc, var,
520 acc::CopyinOp::create(builder, loc, var,
523 copyinOp.setDataClause(acc::DataClause::acc_copy);
524 newDataOp = copyinOp.getOperation();
533 LLVM_DEBUG(llvm::dbgs()
534 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
549static void legalizeValuesInRegion(
Region &accRegion,
552 for (
Value dataClause :
553 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
560template <
typename OpT>
561static void addNewPrivateOperands(OpT &accOp,
563 if (privateOperands.empty())
566 for (
auto priv : privateOperands) {
567 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
568 accOp.getPrivateOperandsMutable().append(priv);
569 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
570 accOp.getFirstprivateOperandsMutable().append(priv);
572 llvm_unreachable(
"unhandled reduction operand");
579 for (
auto *user : res.getUsers())
580 if (isa<ACC_DATA_EXIT_OPS>(user))
598 Value lastDataClause =
nullptr;
599 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
600 if (llvm::find(newDataClauseOperands, dataEntry) ==
601 newDataClauseOperands.end()) {
604 lastDataClause = dataEntry;
608 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
610 Operation *dataEntryOp = dataEntry.getDefiningOp();
611 if (isa<acc::CopyinOp>(dataEntryOp)) {
612 auto copyoutOp = acc::CopyoutOp::create(
616 copyoutOp.setDataClause(acc::DataClause::acc_copy);
617 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
618 auto deleteOp = acc::DeleteOp::create(
619 builder, dataEntryOp->
getLoc(), dataEntry,
623 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
626 llvm_unreachable(
"unhandled data exit");
628 lastDataClause = dataEntry;
639 baseRefs.push_back(val);
644 if (val != baseRefs.front())
645 baseRefs.insert(baseRefs.begin(), val);
649 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
650 val = viewLikeOp.getViewSource();
651 baseRefs.insert(baseRefs.begin(), val);
665 std::find_if(sortedDataClauseOperands.begin(),
666 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
669 auto var = acc::getVar(dataClauseVal.getDefiningOp());
670 auto baseRefs = getBaseRefsChain(var);
676 return std::find(baseRefs.begin(), baseRefs.end(),
677 acc::getVar(newClause)) != baseRefs.end();
680 if (insertPos != sortedDataClauseOperands.end()) {
681 newClause->
moveBefore(insertPos->getDefiningOp());
682 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
688template <
typename OpT>
689void ACCImplicitData::generateImplicitDataOps(
690 ModuleOp &module, OpT computeConstructOp,
691 std::optional<acc::ClauseDefaultValue> &defaultClause,
695 if (defaultClause.has_value() &&
696 defaultClause.value() == acc::ClauseDefaultValue::None)
698 assert(!defaultClause.has_value() ||
699 defaultClause.value() == acc::ClauseDefaultValue::Present);
702 Region &accRegion = computeConstructOp->getRegion(0);
707 auto isCandidate{[&](
Value val) ->
bool {
708 return isCandidateForImplicitData(val, accRegion, accSupport);
711 llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
712 if (candidateVars.empty())
719 if (!candidateVars.empty()) {
720 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
721 << computeConstructOp <<
"\n");
723 auto &domInfo = this->getAnalysis<DominanceInfo>();
724 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
725 auto dominatingDataClauses =
727 for (
auto var : candidateVars) {
728 auto newDataClauseOp = generateDataClauseOpForCandidate(
729 var, module, builder, computeConstructOp, dominatingDataClauses,
731 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
732 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
733 <<
"\t" << *newDataClauseOp <<
"\n");
734 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
737 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
745 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
749 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
750 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
751 generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
755 computeConstructOp.getDataClauseOperands());
756 for (
auto newClause : newDataClauseOperands)
757 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
760 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
761 sortedDataClauseOperands);
763 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
764 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
765 addNewPrivateOperands(computeConstructOp, newPrivateOperands);
766 computeConstructOp.getDataClauseOperandsMutable().assign(
767 sortedDataClauseOperands);
770void ACCImplicitData::runOnOperation() {
771 ModuleOp module = this->getOperation();
775 module.walk([&](Operation *op) {
776 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
777 assert(op->getNumRegions() == 1 && "must have 1 region");
783 generateImplicitDataOps(module, op, defaultClause, accSupport);
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region &region)
Check if a value use is legal in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
static constexpr StringLiteral getFromDefaultClauseAttrName()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
llvm::SmallVector< mlir::Value > getDominatingDataClauses(mlir::Operation *computeConstructOp, mlir::DominanceInfo &domInfo, mlir::PostDominanceInfo &postDomInfo)
Collects all data clauses that dominate the compute construct.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getBaseEntity(mlir::Value val)
bool isOnlyUsedByPrivateClauses(mlir::Value val, mlir::Region &region)
Returns true if this value is only used by acc.private operations in the region.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch