16 #include "llvm/Support/ErrorHandling.h"
20 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
32 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
43 for (
auto operand : operands) {
48 values.push_back({var, accVar});
50 values.push_back({accVar, var});
55 template <
typename Op>
56 static void replaceAllUsesInAccComputeRegionsWith(
Value orig,
Value replacement,
58 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
59 if (outerRegion.
isAncestor(use.getOwner()->getParentRegion())) {
60 if constexpr (std::is_same_v<Op, acc::DataOp> ||
61 std::is_same_v<Op, acc::DeclareOp>) {
64 if (insideAccComputeRegion(use.getOwner())) {
74 template <
typename Op>
75 static void replaceAllUsesInUnstructuredComputeRegionWith(
80 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
82 for (
auto *user : op.getToken().getUsers()) {
83 if (
auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84 exitOps.push_back(declareExit);
90 for (
auto p : values) {
91 Value hostVal = std::get<0>(p);
92 Value deviceVal = std::get<1>(p);
93 for (
auto &use : llvm::make_early_inc_range(hostVal.
getUses())) {
102 bool hasPostDominatingExit =
false;
103 for (
auto *exit : exitOps) {
105 hasPostDominatingExit =
true;
110 if (!hasPostDominatingExit)
113 if (insideAccComputeRegion(owner))
119 template <
typename Op>
121 collectAndReplaceInRegion(
Op &op,
bool hostToDevice,
126 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
127 collectVars(op.getReductionOperands(), values, hostToDevice);
128 collectVars(op.getPrivateOperands(), values, hostToDevice);
130 collectVars(op.getDataClauseOperands(), values, hostToDevice);
131 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
132 !std::is_same_v<Op, acc::DataOp> &&
133 !std::is_same_v<Op, acc::DeclareOp> &&
134 !std::is_same_v<Op, acc::HostDataOp> &&
135 !std::is_same_v<Op, acc::DeclareEnterOp>) {
136 collectVars(op.getReductionOperands(), values, hostToDevice);
137 collectVars(op.getPrivateOperands(), values, hostToDevice);
138 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
142 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143 assert(domInfo && postDomInfo &&
144 "Dominance info required for DeclareEnterOp");
145 replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
148 for (
auto p : values) {
149 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
155 class LegalizeDataValuesInRegion
156 :
public acc::impl::LegalizeDataValuesInRegionBase<
157 LegalizeDataValuesInRegion> {
159 using LegalizeDataValuesInRegionBase<
160 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
162 void runOnOperation()
override {
163 func::FuncOp funcOp = getOperation();
164 bool replaceHostVsDevice = this->hostToDevice.getValue();
169 bool computedDomInfo =
false;
172 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
173 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
174 applyToAccDataConstruct) &&
175 !isa<acc::DeclareEnterOp>(*op))
178 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
179 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
180 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
181 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
182 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
183 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
184 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
185 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
186 }
else if (
auto dataOp = dyn_cast<acc::DataOp>(*op)) {
187 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
188 }
else if (
auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
189 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
190 }
else if (
auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
191 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
192 }
else if (
auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193 if (!computedDomInfo) {
196 computedDomInfo =
true;
198 collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
201 llvm_unreachable(
"unsupported acc region op");
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
A class for computing basic postdominance information.
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Include the generated interface declarations.