15 #include "llvm/Support/ErrorHandling.h"
19 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
20 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
31 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
42 for (
auto operand : operands) {
45 if (varPtr && accPtr) {
47 values.push_back({varPtr, accPtr});
49 values.push_back({accPtr, varPtr});
54 template <
typename Op>
55 static void replaceAllUsesInAccComputeRegionsWith(
Value orig,
Value replacement,
57 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
58 if (outerRegion.
isAncestor(use.getOwner()->getParentRegion())) {
59 if constexpr (std::is_same_v<Op, acc::DataOp> ||
60 std::is_same_v<Op, acc::DeclareOp>) {
63 if (insideAccComputeRegion(use.getOwner())) {
73 template <
typename Op>
74 static void collectAndReplaceInRegion(
Op &op,
bool hostToDevice) {
77 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
78 collectPtrs(op.getReductionOperands(), values, hostToDevice);
79 collectPtrs(op.getPrivateOperands(), values, hostToDevice);
81 collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
82 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83 !std::is_same_v<Op, acc::DataOp> &&
84 !std::is_same_v<Op, acc::DeclareOp>) {
85 collectPtrs(op.getReductionOperands(), values, hostToDevice);
86 collectPtrs(op.getPrivateOperands(), values, hostToDevice);
87 collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
92 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
96 class LegalizeDataValuesInRegion
97 :
public acc::impl::LegalizeDataValuesInRegionBase<
98 LegalizeDataValuesInRegion> {
100 using LegalizeDataValuesInRegionBase<
101 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
103 void runOnOperation()
override {
104 func::FuncOp funcOp = getOperation();
105 bool replaceHostVsDevice = this->hostToDevice.getValue();
108 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110 applyToAccDataConstruct))
113 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
114 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
115 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
116 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
117 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
118 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
119 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
120 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121 }
else if (
auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123 }
else if (
auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
126 llvm_unreachable(
"unsupported acc region op");
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Include the generated interface declarations.