15 #include "llvm/Support/ErrorHandling.h"
19 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
20 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
31 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
42 for (
auto operand : operands) {
47 values.push_back({var, accVar});
49 values.push_back({accVar, var});
54 template <
typename Op>
55 static void replaceAllUsesInAccComputeRegionsWith(
Value orig,
Value replacement,
57 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
58 if (outerRegion.
isAncestor(use.getOwner()->getParentRegion())) {
59 if constexpr (std::is_same_v<Op, acc::DataOp> ||
60 std::is_same_v<Op, acc::DeclareOp>) {
63 if (insideAccComputeRegion(use.getOwner())) {
73 template <
typename Op>
74 static void collectAndReplaceInRegion(
Op &op,
bool hostToDevice) {
77 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
78 collectVars(op.getReductionOperands(), values, hostToDevice);
79 collectVars(op.getPrivateOperands(), values, hostToDevice);
81 collectVars(op.getDataClauseOperands(), values, hostToDevice);
82 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83 !std::is_same_v<Op, acc::DataOp> &&
84 !std::is_same_v<Op, acc::DeclareOp> &&
85 !std::is_same_v<Op, acc::HostDataOp>) {
86 collectVars(op.getReductionOperands(), values, hostToDevice);
87 collectVars(op.getPrivateOperands(), values, hostToDevice);
88 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
93 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
97 class LegalizeDataValuesInRegion
98 :
public acc::impl::LegalizeDataValuesInRegionBase<
99 LegalizeDataValuesInRegion> {
101 using LegalizeDataValuesInRegionBase<
102 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
104 void runOnOperation()
override {
105 func::FuncOp funcOp = getOperation();
106 bool replaceHostVsDevice = this->hostToDevice.getValue();
109 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
111 applyToAccDataConstruct))
114 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
115 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
116 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
117 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
118 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
119 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
120 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
121 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
122 }
else if (
auto dataOp = dyn_cast<acc::DataOp>(*op)) {
123 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
124 }
else if (
auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
125 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
126 }
else if (
auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
129 llvm_unreachable(
"unsupported acc region op");
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Include the generated interface declarations.