16#include "llvm/Support/ErrorHandling.h"
20#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
32 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
43 for (
auto operand : operands) {
48 values.push_back({var, accVar});
50 values.push_back({accVar, var});
58 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
59 if (outerRegion.
isAncestor(use.getOwner()->getParentRegion())) {
60 if constexpr (std::is_same_v<Op, acc::DataOp> ||
61 std::is_same_v<Op, acc::DeclareOp>) {
64 if (insideAccComputeRegion(use.getOwner())) {
75static void replaceAllUsesInUnstructuredComputeRegionWith(
80 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
82 for (
auto *user : op.getToken().getUsers()) {
83 if (
auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84 exitOps.push_back(declareExit);
90 for (
auto p : values) {
91 Value hostVal = std::get<0>(p);
92 Value deviceVal = std::get<1>(p);
93 for (
auto &use : llvm::make_early_inc_range(hostVal.
getUses())) {
102 bool hasPostDominatingExit =
false;
103 for (
auto *exit : exitOps) {
105 hasPostDominatingExit =
true;
110 if (!hasPostDominatingExit)
113 if (insideAccComputeRegion(owner))
119template <
typename Op>
121collectAndReplaceInRegion(
Op &op,
bool hostToDevice,
126 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
127 collectVars(op.getReductionOperands(), values, hostToDevice);
128 collectVars(op.getPrivateOperands(), values, hostToDevice);
130 collectVars(op.getDataClauseOperands(), values, hostToDevice);
131 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
132 !std::is_same_v<Op, acc::DataOp> &&
133 !std::is_same_v<Op, acc::DeclareOp> &&
134 !std::is_same_v<Op, acc::HostDataOp> &&
135 !std::is_same_v<Op, acc::DeclareEnterOp>) {
136 collectVars(op.getReductionOperands(), values, hostToDevice);
137 collectVars(op.getPrivateOperands(), values, hostToDevice);
138 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
142 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143 assert(domInfo && postDomInfo &&
144 "Dominance info required for DeclareEnterOp");
145 replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
148 for (
auto p : values) {
149 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
155class LegalizeDataValuesInRegion
157 LegalizeDataValuesInRegion> {
159 using LegalizeDataValuesInRegionBase<
160 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
162 void runOnOperation()
override {
163 func::FuncOp funcOp = getOperation();
164 bool replaceHostVsDevice = this->hostToDevice.getValue();
167 DominanceInfo domInfo;
168 PostDominanceInfo postDomInfo;
169 bool computedDomInfo =
false;
171 funcOp.walk([&](Operation *op) {
172 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
173 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
174 applyToAccDataConstruct) &&
175 !isa<acc::DeclareEnterOp>(*op))
178 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
179 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
180 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
181 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
182 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
183 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
184 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
185 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
186 }
else if (
auto dataOp = dyn_cast<acc::DataOp>(*op)) {
187 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
188 }
else if (
auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
189 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
190 }
else if (
auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
191 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
192 }
else if (
auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193 if (!computedDomInfo) {
194 domInfo = DominanceInfo(funcOp);
195 postDomInfo = PostDominanceInfo(funcOp);
196 computedDomInfo =
true;
198 collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
201 llvm_unreachable(
"unsupported acc region op");
…if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
A class for computing basic postdominance information.
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Include the generated interface declarations.