16#include "llvm/Support/ErrorHandling.h"
20#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
36 for (
auto operand : operands) {
41 values.push_back({var, accVar});
43 values.push_back({accVar, var});
51 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
52 if (outerRegion.
isAncestor(use.getOwner()->getParentRegion())) {
53 if constexpr (std::is_same_v<Op, acc::DataOp> ||
54 std::is_same_v<Op, acc::DeclareOp>) {
57 if (insideAccComputeRegion(use.getOwner())) {
68static void replaceAllUsesInUnstructuredComputeRegionWith(
73 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
75 for (
auto *user : op.getToken().getUsers()) {
76 if (
auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
77 exitOps.push_back(declareExit);
83 for (
auto p : values) {
84 Value hostVal = std::get<0>(p);
85 Value deviceVal = std::get<1>(p);
86 for (
auto &use : llvm::make_early_inc_range(hostVal.
getUses())) {
95 bool hasPostDominatingExit =
false;
96 for (
auto *exit : exitOps) {
98 hasPostDominatingExit =
true;
103 if (!hasPostDominatingExit)
106 if (insideAccComputeRegion(owner))
112template <
typename Op>
114collectAndReplaceInRegion(
Op &op,
bool hostToDevice,
119 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
120 collectVars(op.getReductionOperands(), values, hostToDevice);
121 collectVars(op.getPrivateOperands(), values, hostToDevice);
123 collectVars(op.getDataClauseOperands(), values, hostToDevice);
124 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
125 !std::is_same_v<Op, acc::DataOp> &&
126 !std::is_same_v<Op, acc::DeclareOp> &&
127 !std::is_same_v<Op, acc::HostDataOp> &&
128 !std::is_same_v<Op, acc::DeclareEnterOp>) {
129 collectVars(op.getReductionOperands(), values, hostToDevice);
130 collectVars(op.getPrivateOperands(), values, hostToDevice);
131 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
135 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
136 assert(domInfo && postDomInfo &&
137 "Dominance info required for DeclareEnterOp");
138 replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
141 for (
auto p : values) {
142 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
148class LegalizeDataValuesInRegion
149 :
public acc::impl::LegalizeDataValuesInRegionBase<
150 LegalizeDataValuesInRegion> {
152 using LegalizeDataValuesInRegionBase<
153 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
155 void runOnOperation()
override {
156 func::FuncOp funcOp = getOperation();
157 bool replaceHostVsDevice = this->hostToDevice.getValue();
160 DominanceInfo domInfo;
161 PostDominanceInfo postDomInfo;
162 bool computedDomInfo =
false;
164 funcOp.walk([&](Operation *op) {
165 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
166 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
167 applyToAccDataConstruct) &&
168 !isa<acc::DeclareEnterOp>(*op))
171 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
172 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
173 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
174 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
175 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
176 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
177 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
178 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
179 }
else if (
auto dataOp = dyn_cast<acc::DataOp>(*op)) {
180 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
181 }
else if (
auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
182 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
183 }
else if (
auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
184 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
185 }
else if (
auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
186 if (!computedDomInfo) {
187 domInfo = DominanceInfo(funcOp);
188 postDomInfo = PostDominanceInfo(funcOp);
189 computedDomInfo =
true;
191 collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
194 llvm_unreachable(
"unsupported acc region op");
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A class for computing basic postdominance information.
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
#define ACC_COMPUTE_CONSTRUCT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Include the generated interface declarations.