MLIR  21.0.0git
LegalizeDataValues.cpp
Go to the documentation of this file.
1 //===- LegalizeDataValues.cpp - -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/Dominance.h"
14 #include "mlir/Pass/Pass.h"
16 #include "llvm/Support/ErrorHandling.h"
17 
18 namespace mlir {
19 namespace acc {
20 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
22 } // namespace acc
23 } // namespace mlir
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 static bool insideAccComputeRegion(mlir::Operation *op) {
30  mlir::Operation *parent{op->getParentOp()};
31  while (parent) {
32  if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
33  return true;
34  }
35  parent = parent->getParentOp();
36  }
37  return false;
38 }
39 
40 static void collectVars(mlir::ValueRange operands,
41  llvm::SmallVector<std::pair<Value, Value>> &values,
42  bool hostToDevice) {
43  for (auto operand : operands) {
44  Value var = acc::getVar(operand.getDefiningOp());
45  Value accVar = acc::getAccVar(operand.getDefiningOp());
46  if (var && accVar) {
47  if (hostToDevice)
48  values.push_back({var, accVar});
49  else
50  values.push_back({accVar, var});
51  }
52  }
53 }
54 
55 template <typename Op>
56 static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
57  Region &outerRegion) {
58  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
59  if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
60  if constexpr (std::is_same_v<Op, acc::DataOp> ||
61  std::is_same_v<Op, acc::DeclareOp>) {
62  // For data construct regions, only replace uses in contained compute
63  // regions.
64  if (insideAccComputeRegion(use.getOwner())) {
65  use.set(replacement);
66  }
67  } else {
68  use.set(replacement);
69  }
70  }
71  }
72 }
73 
74 template <typename Op>
75 static void replaceAllUsesInUnstructuredComputeRegionWith(
76  Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
77  DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78 
80  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81  // For declare enter/exit pairs, collect all exit ops
82  for (auto *user : op.getToken().getUsers()) {
83  if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84  exitOps.push_back(declareExit);
85  }
86  if (exitOps.empty())
87  return;
88  }
89 
90  for (auto p : values) {
91  Value hostVal = std::get<0>(p);
92  Value deviceVal = std::get<1>(p);
93  for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
94  Operation *owner = use.getOwner();
95 
96  // Check It's the case that the acc entry operation dominates the use.
97  if (!domInfo.dominates(op.getOperation(), owner))
98  continue;
99 
100  // Check It's the case that at least one of the acc exit operations
101  // post-dominates the use
102  bool hasPostDominatingExit = false;
103  for (auto *exit : exitOps) {
104  if (postDomInfo.postDominates(exit, owner)) {
105  hasPostDominatingExit = true;
106  break;
107  }
108  }
109 
110  if (!hasPostDominatingExit)
111  continue;
112 
113  if (insideAccComputeRegion(owner))
114  use.set(deviceVal);
115  }
116  }
117 }
118 
119 template <typename Op>
120 static void
121 collectAndReplaceInRegion(Op &op, bool hostToDevice,
122  DominanceInfo *domInfo = nullptr,
123  PostDominanceInfo *postDomInfo = nullptr) {
125 
126  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
127  collectVars(op.getReductionOperands(), values, hostToDevice);
128  collectVars(op.getPrivateOperands(), values, hostToDevice);
129  } else {
130  collectVars(op.getDataClauseOperands(), values, hostToDevice);
131  if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
132  !std::is_same_v<Op, acc::DataOp> &&
133  !std::is_same_v<Op, acc::DeclareOp> &&
134  !std::is_same_v<Op, acc::HostDataOp> &&
135  !std::is_same_v<Op, acc::DeclareEnterOp>) {
136  collectVars(op.getReductionOperands(), values, hostToDevice);
137  collectVars(op.getPrivateOperands(), values, hostToDevice);
138  collectVars(op.getFirstprivateOperands(), values, hostToDevice);
139  }
140  }
141 
142  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143  assert(domInfo && postDomInfo &&
144  "Dominance info required for DeclareEnterOp");
145  replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
146  *postDomInfo);
147  } else {
148  for (auto p : values) {
149  replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
150  op.getRegion());
151  }
152  }
153 }
154 
155 class LegalizeDataValuesInRegion
156  : public acc::impl::LegalizeDataValuesInRegionBase<
157  LegalizeDataValuesInRegion> {
158 public:
159  using LegalizeDataValuesInRegionBase<
160  LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
161 
162  void runOnOperation() override {
163  func::FuncOp funcOp = getOperation();
164  bool replaceHostVsDevice = this->hostToDevice.getValue();
165 
166  // Initialize dominance info
167  DominanceInfo domInfo;
168  PostDominanceInfo postDomInfo;
169  bool computedDomInfo = false;
170 
171  funcOp.walk([&](Operation *op) {
172  if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
173  !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
174  applyToAccDataConstruct) &&
175  !isa<acc::DeclareEnterOp>(*op))
176  return;
177 
178  if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
179  collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
180  } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
181  collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
182  } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
183  collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
184  } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
185  collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
186  } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
187  collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
188  } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
189  collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
190  } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
191  collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
192  } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193  if (!computedDomInfo) {
194  domInfo = DominanceInfo(funcOp);
195  postDomInfo = PostDominanceInfo(funcOp);
196  computedDomInfo = true;
197  }
198  collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
199  &postDomInfo);
200  } else {
201  llvm_unreachable("unsupported acc region op");
202  }
203  });
204  }
205 };
206 
207 } // end anonymous namespace
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
Definition: Dominance.h:158
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
A class for computing basic postdominance information.
Definition: Dominance.h:204
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
Definition: Dominance.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:3975
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3944
Include the generated interface declarations.