MLIR  20.0.0git
LegalizeDataValues.cpp
Go to the documentation of this file.
1 //===- LegalizeDataValues.cpp - -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/Pass/Pass.h"
15 #include "llvm/Support/ErrorHandling.h"
16 
17 namespace mlir {
18 namespace acc {
19 #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
20 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
21 } // namespace acc
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 static bool insideAccComputeRegion(mlir::Operation *op) {
29  mlir::Operation *parent{op->getParentOp()};
30  while (parent) {
31  if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
32  return true;
33  }
34  parent = parent->getParentOp();
35  }
36  return false;
37 }
38 
39 static void collectPtrs(mlir::ValueRange operands,
40  llvm::SmallVector<std::pair<Value, Value>> &values,
41  bool hostToDevice) {
42  for (auto operand : operands) {
43  Value varPtr = acc::getVarPtr(operand.getDefiningOp());
44  Value accPtr = acc::getAccPtr(operand.getDefiningOp());
45  if (varPtr && accPtr) {
46  if (hostToDevice)
47  values.push_back({varPtr, accPtr});
48  else
49  values.push_back({accPtr, varPtr});
50  }
51  }
52 }
53 
54 template <typename Op>
55 static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
56  Region &outerRegion) {
57  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
58  if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
59  if constexpr (std::is_same_v<Op, acc::DataOp> ||
60  std::is_same_v<Op, acc::DeclareOp>) {
61  // For data construct regions, only replace uses in contained compute
62  // regions.
63  if (insideAccComputeRegion(use.getOwner())) {
64  use.set(replacement);
65  }
66  } else {
67  use.set(replacement);
68  }
69  }
70  }
71 }
72 
73 template <typename Op>
74 static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
76 
77  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
78  collectPtrs(op.getReductionOperands(), values, hostToDevice);
79  collectPtrs(op.getPrivateOperands(), values, hostToDevice);
80  } else {
81  collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
82  if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83  !std::is_same_v<Op, acc::DataOp> &&
84  !std::is_same_v<Op, acc::DeclareOp>) {
85  collectPtrs(op.getReductionOperands(), values, hostToDevice);
86  collectPtrs(op.getPrivateOperands(), values, hostToDevice);
87  collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
88  }
89  }
90 
91  for (auto p : values)
92  replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
93  op.getRegion());
94 }
95 
96 class LegalizeDataValuesInRegion
97  : public acc::impl::LegalizeDataValuesInRegionBase<
98  LegalizeDataValuesInRegion> {
99 public:
100  using LegalizeDataValuesInRegionBase<
101  LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
102 
103  void runOnOperation() override {
104  func::FuncOp funcOp = getOperation();
105  bool replaceHostVsDevice = this->hostToDevice.getValue();
106 
107  funcOp.walk([&](Operation *op) {
108  if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
109  !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
110  applyToAccDataConstruct))
111  return;
112 
113  if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
114  collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
115  } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
116  collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
117  } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
118  collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
119  } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
120  collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
121  } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
122  collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
123  } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
124  collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
125  } else {
126  llvm_unreachable("unsupported acc region op");
127  }
128  });
129  }
130 };
131 
132 } // end anonymous namespace
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2876
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2886
Include the generated interface declarations.