201#include "llvm/ADT/SmallVectorExtras.h"
215#include "llvm/ADT/STLExtras.h"
216#include "llvm/ADT/SmallVector.h"
217#include "llvm/ADT/TypeSwitch.h"
218#include "llvm/Support/ErrorHandling.h"
219#include <type_traits>
223#define GEN_PASS_DEF_ACCIMPLICITDATA
224#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
228#define DEBUG_TYPE "acc-implicit-data"
234class ACCImplicitData :
public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
236 using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase;
238 void runOnOperation()
override;
243 template <
typename OpT>
244 Operation *getOriginalDataClauseOpForAlias(
250 template <
typename OpT>
251 Operation *generateDataClauseOpForCandidate(
252 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
254 const std::optional<acc::ClauseDefaultValue> &defaultClause);
257 template <
typename OpT>
259 generateImplicitDataOps(ModuleOp &module, OpT computeConstructOp,
260 std::optional<acc::ClauseDefaultValue> &defaultClause,
264 acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module,
Value var,
269 acc::FirstprivateRecipeOp
270 generateFirstprivateRecipe(ModuleOp &module,
Value var,
Location loc,
275 void generateRecipes(ModuleOp &module,
OpBuilder &builder,
282static bool isCandidateForImplicitData(
Value val,
Region &accRegion,
291 if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.
getDefiningOp()))
305template <
typename OpT>
306Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
309 auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
310 for (
auto dataClause : dominatingDataClauses) {
311 if (
auto *dataClauseOp = dataClause.getDefiningOp()) {
313 if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
314 acc::DevicePtrOp>(dataClauseOp))
315 if (aliasAnalysis.alias(
acc::getVar(dataClauseOp), var).isMust()) {
316 LLVM_DEBUG(llvm::dbgs()
317 <<
"Using existing data clause:\n\t" << *dataClauseOp
318 <<
"\n\tas reference when processing var:\n\t" << var
328static void fillInBoundsForUnknownDimensions(
Operation *dataClauseOp,
339 if (
auto mappableTy = dyn_cast<acc::MappableType>(type)) {
340 if (mappableTy.hasUnknownDimensions()) {
343 if (std::is_same_v<
decltype(dataClauseOp), acc::DevicePtrOp>)
347 auto bounds = mappableTy.generateAccBounds(var, builder);
349 dataClauseOp.getBoundsMutable().assign(bounds);
356ACCImplicitData::generatePrivateRecipe(ModuleOp &module,
Value var,
360 std::string recipeName =
361 accSupport.
getRecipeName(acc::RecipeKind::private_recipe, type, var);
364 auto existingRecipe =
module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
366 return existingRecipe;
373 acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, var);
374 if (!recipe.has_value())
375 return accSupport.
emitNYI(loc,
"implicit private"),
nullptr;
376 return recipe.value();
379acc::FirstprivateRecipeOp
380ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module,
Value var,
384 std::string recipeName =
385 accSupport.
getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
388 auto existingRecipe =
389 module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
391 return existingRecipe;
397 auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
399 if (!recipe.has_value())
400 return accSupport.
emitNYI(loc,
"implicit firstprivate"),
nullptr;
401 return recipe.value();
404void ACCImplicitData::generateRecipes(ModuleOp &module,
OpBuilder &builder,
407 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
408 for (
auto var : newOperands) {
411 auto recipe = generatePrivateRecipe(
414 privateOp.setRecipeAttr(
415 SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
416 }
else if (
auto firstprivateOp = var.
getDefiningOp<acc::FirstprivateOp>()) {
417 auto recipe = generateFirstprivateRecipe(
420 firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
421 module->getContext(), recipe.getSymName().str()));
437template <
typename OpT>
438Operation *ACCImplicitData::generateDataClauseOpForCandidate(
439 Value var, ModuleOp &module,
OpBuilder &builder, OpT computeConstructOp,
441 const std::optional<acc::ClauseDefaultValue> &defaultClause) {
442 auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
443 acc::VariableTypeCategory typeCategory =
444 acc::VariableTypeCategory::uncategorized;
445 if (
auto mappableTy = dyn_cast<acc::MappableType>(var.
getType())) {
446 typeCategory = mappableTy.getTypeCategory(var);
447 }
else if (
auto pointerLikeTy =
448 dyn_cast<acc::PointerLikeType>(var.
getType())) {
449 typeCategory = pointerLikeTy.getPointeeTypeCategory(
451 pointerLikeTy.getElementType());
455 acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
456 bool isAnyAggregate = acc::bitEnumContainsAny(
457 typeCategory, acc::VariableTypeCategory::aggregate);
458 Location loc = computeConstructOp->getLoc();
462 LLVM_DEBUG(llvm::dbgs() <<
"Using deviceptr clause because variable is "
464 return acc::DevicePtrOp::create(builder, loc, var,
470 op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
471 dominatingDataClauses);
473 if (isa<acc::NoCreateOp>(op))
474 return acc::NoCreateOp::create(builder, loc, var,
479 if (isa<acc::DevicePtrOp>(op))
480 return acc::DevicePtrOp::create(builder, loc, var,
487 return acc::PresentOp::create(builder, loc, var,
494 if (enableImplicitReductionCopy &&
496 computeConstructOp->getRegion(0))) {
498 acc::CopyinOp::create(builder, loc, var,
501 copyinOp.setDataClause(acc::DataClause::acc_reduction);
502 return copyinOp.getOperation();
504 if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
505 std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
513 acc::CopyinOp::create(builder, loc, var,
516 copyinOp.setDataClause(acc::DataClause::acc_copy);
517 return copyinOp.getOperation();
520 return acc::FirstprivateOp::create(builder, loc, var,
524 }
else if (isAnyAggregate) {
528 if (defaultClause.has_value() &&
529 defaultClause.value() == acc::ClauseDefaultValue::Present) {
530 newDataOp = acc::PresentOp::create(builder, loc, var,
537 acc::CopyinOp::create(builder, loc, var,
540 copyinOp.setDataClause(acc::DataClause::acc_copy);
541 newDataOp = copyinOp.getOperation();
550 LLVM_DEBUG(llvm::dbgs()
551 <<
"Unhandled case for implicit data mapping " << var <<
"\n");
566static void legalizeValuesInRegion(
Region &accRegion,
569 for (
Value dataClause :
570 llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
577template <
typename OpT>
578static void addNewPrivateOperands(OpT &accOp,
580 if (privateOperands.empty())
583 for (
auto priv : privateOperands) {
584 if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
585 accOp.getPrivateOperandsMutable().append(priv);
586 }
else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
587 accOp.getFirstprivateOperandsMutable().append(priv);
589 llvm_unreachable(
"unhandled reduction operand");
596 for (
auto *user : res.getUsers())
597 if (isa<ACC_DATA_EXIT_OPS>(user))
615 Value lastDataClause =
nullptr;
616 for (
auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
617 if (llvm::find(newDataClauseOperands, dataEntry) ==
618 newDataClauseOperands.end()) {
621 lastDataClause = dataEntry;
625 if (
auto *dataExitOp = findDataExitOp(lastDataClause.
getDefiningOp()))
627 Operation *dataEntryOp = dataEntry.getDefiningOp();
628 if (isa<acc::CopyinOp>(dataEntryOp)) {
629 auto copyoutOp = acc::CopyoutOp::create(
633 copyoutOp.setDataClause(acc::DataClause::acc_copy);
634 }
else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
635 auto deleteOp = acc::DeleteOp::create(
636 builder, dataEntryOp->
getLoc(), dataEntry,
640 }
else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
643 llvm_unreachable(
"unhandled data exit");
645 lastDataClause = dataEntry;
656 baseRefs.push_back(val);
661 if (val != baseRefs.front())
662 baseRefs.insert(baseRefs.begin(), val);
666 if (
auto viewLikeOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
667 val = viewLikeOp.getViewSource();
668 baseRefs.insert(baseRefs.begin(), val);
682 std::find_if(sortedDataClauseOperands.begin(),
683 sortedDataClauseOperands.end(), [&](
Value dataClauseVal) {
686 auto var = acc::getVar(dataClauseVal.getDefiningOp());
687 auto baseRefs = getBaseRefsChain(var);
693 return std::find(baseRefs.begin(), baseRefs.end(),
694 acc::getVar(newClause)) != baseRefs.end();
697 if (insertPos != sortedDataClauseOperands.end()) {
698 newClause->
moveBefore(insertPos->getDefiningOp());
699 sortedDataClauseOperands.insert(insertPos,
acc::getAccVar(newClause));
705template <
typename OpT>
706void ACCImplicitData::generateImplicitDataOps(
707 ModuleOp &module, OpT computeConstructOp,
708 std::optional<acc::ClauseDefaultValue> &defaultClause,
712 if (defaultClause.has_value() &&
713 defaultClause.value() == acc::ClauseDefaultValue::None)
715 assert(!defaultClause.has_value() ||
716 defaultClause.value() == acc::ClauseDefaultValue::Present);
719 Region &accRegion = computeConstructOp->getRegion(0);
724 auto isCandidate{[&](
Value val) ->
bool {
725 return isCandidateForImplicitData(val, accRegion, accSupport);
727 auto candidateVars(llvm::filter_to_vector(liveInValues, isCandidate));
728 if (candidateVars.empty())
735 if (!candidateVars.empty()) {
736 LLVM_DEBUG(llvm::dbgs() <<
"== Generating clauses for ==\n"
737 << computeConstructOp <<
"\n");
739 auto &domInfo = this->getAnalysis<DominanceInfo>();
740 auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
741 auto dominatingDataClauses =
743 for (
auto var : candidateVars) {
744 auto newDataClauseOp = generateDataClauseOpForCandidate(
745 var, module, builder, computeConstructOp, dominatingDataClauses,
747 fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
748 LLVM_DEBUG(llvm::dbgs() <<
"Generated data clause for " << var <<
":\n"
749 <<
"\t" << *newDataClauseOp <<
"\n");
750 if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
753 }
else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
761 legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
765 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
766 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
767 generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
771 computeConstructOp.getDataClauseOperands());
772 for (
auto newClause : newDataClauseOperands)
773 insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
776 generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
777 sortedDataClauseOperands);
779 if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
780 !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
781 addNewPrivateOperands(computeConstructOp, newPrivateOperands);
782 computeConstructOp.getDataClauseOperandsMutable().assign(
783 sortedDataClauseOperands);
786void ACCImplicitData::runOnOperation() {
787 ModuleOp module = this->getOperation();
791 module.walk([&](Operation *op) {
792 if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
793 assert(op->getNumRegions() == 1 && "must have 1 region");
799 generateImplicitDataOps(module, op, defaultClause, accSupport);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region ®ion)
Check if a value use is legal in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
std::string getRecipeName(RecipeKind kind, Type type, Value var)
Get the recipe name for a given type and value.
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static constexpr StringLiteral getFromDefaultClauseAttrName()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
bool isPointerLikeType(mlir::Type type)
Used to check whether the provided type implements the PointerLikeType interface.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
bool isOnlyUsedByReductionClauses(mlir::Value val, mlir::Region ®ion)
Returns true if this value is only used by acc.reduction operations in the region.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
llvm::SmallVector< mlir::Value > getDominatingDataClauses(mlir::Operation *computeConstructOp, mlir::DominanceInfo &domInfo, mlir::PostDominanceInfo &postDomInfo)
Collects all data clauses that dominate the compute construct.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
bool isDeviceValue(mlir::Value val)
Check if a value represents device data.
mlir::Value getBaseEntity(mlir::Value val)
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
llvm::TypeSwitch< T, ResultT > TypeSwitch