MLIR 23.0.0git
LegalizeDataValues.cpp
Go to the documentation of this file.
1//===- LegalizeDataValues.cpp - -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13#include "mlir/IR/Dominance.h"
14#include "mlir/Pass/Pass.h"
16#include "llvm/Support/ErrorHandling.h"
17
18namespace mlir {
19namespace acc {
20#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
22} // namespace acc
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28
29static bool insideAccComputeRegion(mlir::Operation *op) {
31}
32
33static void collectVars(mlir::ValueRange operands,
34 llvm::SmallVector<std::pair<Value, Value>> &values,
35 bool hostToDevice) {
36 for (auto operand : operands) {
37 Value var = acc::getVar(operand.getDefiningOp());
38 Value accVar = acc::getAccVar(operand.getDefiningOp());
39 if (var && accVar) {
40 if (hostToDevice)
41 values.push_back({var, accVar});
42 else
43 values.push_back({accVar, var});
44 }
45 }
46}
47
48template <typename Op>
49static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
50 Region &outerRegion) {
51 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
52 if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
53 if constexpr (std::is_same_v<Op, acc::DataOp> ||
54 std::is_same_v<Op, acc::DeclareOp>) {
55 // For data construct regions, only replace uses in contained compute
56 // regions.
57 if (insideAccComputeRegion(use.getOwner())) {
58 use.set(replacement);
59 }
60 } else {
61 use.set(replacement);
62 }
63 }
64 }
65}
66
67template <typename Op>
68static void replaceAllUsesInUnstructuredComputeRegionWith(
69 Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
70 DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
71
73 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
74 // For declare enter/exit pairs, collect all exit ops
75 for (auto *user : op.getToken().getUsers()) {
76 if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
77 exitOps.push_back(declareExit);
78 }
79 if (exitOps.empty())
80 return;
81 }
82
83 for (auto p : values) {
84 Value hostVal = std::get<0>(p);
85 Value deviceVal = std::get<1>(p);
86 for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
87 Operation *owner = use.getOwner();
88
89 // Check It's the case that the acc entry operation dominates the use.
90 if (!domInfo.dominates(op.getOperation(), owner))
91 continue;
92
93 // Check It's the case that at least one of the acc exit operations
94 // post-dominates the use
95 bool hasPostDominatingExit = false;
96 for (auto *exit : exitOps) {
97 if (postDomInfo.postDominates(exit, owner)) {
98 hasPostDominatingExit = true;
99 break;
100 }
101 }
102
103 if (!hasPostDominatingExit)
104 continue;
105
106 if (insideAccComputeRegion(owner))
107 use.set(deviceVal);
108 }
109 }
110}
111
112template <typename Op>
113static void
114collectAndReplaceInRegion(Op &op, bool hostToDevice,
115 DominanceInfo *domInfo = nullptr,
116 PostDominanceInfo *postDomInfo = nullptr) {
118
119 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
120 collectVars(op.getReductionOperands(), values, hostToDevice);
121 collectVars(op.getPrivateOperands(), values, hostToDevice);
122 } else {
123 collectVars(op.getDataClauseOperands(), values, hostToDevice);
124 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
125 !std::is_same_v<Op, acc::DataOp> &&
126 !std::is_same_v<Op, acc::DeclareOp> &&
127 !std::is_same_v<Op, acc::HostDataOp> &&
128 !std::is_same_v<Op, acc::DeclareEnterOp>) {
129 collectVars(op.getReductionOperands(), values, hostToDevice);
130 collectVars(op.getPrivateOperands(), values, hostToDevice);
131 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
132 }
133 }
134
135 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
136 assert(domInfo && postDomInfo &&
137 "Dominance info required for DeclareEnterOp");
138 replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
139 *postDomInfo);
140 } else {
141 for (auto p : values) {
142 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
143 op.getRegion());
144 }
145 }
146}
147
148class LegalizeDataValuesInRegion
149 : public acc::impl::LegalizeDataValuesInRegionBase<
150 LegalizeDataValuesInRegion> {
151public:
152 using LegalizeDataValuesInRegionBase<
153 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
154
155 void runOnOperation() override {
156 func::FuncOp funcOp = getOperation();
157 bool replaceHostVsDevice = this->hostToDevice.getValue();
158
159 // Initialize dominance info
160 DominanceInfo domInfo;
161 PostDominanceInfo postDomInfo;
162 bool computedDomInfo = false;
163
164 funcOp.walk([&](Operation *op) {
165 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
166 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
167 applyToAccDataConstruct) &&
168 !isa<acc::DeclareEnterOp>(*op))
169 return;
170
171 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
172 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
173 } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
174 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
175 } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
176 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
177 } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
178 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
179 } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
180 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
181 } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
182 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
183 } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
184 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
185 } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
186 if (!computedDomInfo) {
187 domInfo = DominanceInfo(funcOp);
188 postDomInfo = PostDominanceInfo(funcOp);
189 computedDomInfo = true;
190 }
191 collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
192 &postDomInfo);
193 } else {
194 llvm_unreachable("unsupported acc region op");
195 }
196 });
197 }
198};
199
200} // end anonymous namespace
If copies could not be generated due to yet-unimplemented cases, `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
A class for computing basic postdominance information.
Definition Dominance.h:204
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
Definition Dominance.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
#define ACC_COMPUTE_CONSTRUCT_OPS
Definition OpenACC.h:60
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5098
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5067
Include the generated interface declarations.