MLIR  19.0.0git
LegalizeData.cpp
Go to the documentation of this file.
1 //===- LegalizeData.cpp - -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 namespace acc {
18 #define GEN_PASS_DEF_LEGALIZEDATAINREGION
19 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
20 } // namespace acc
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 static void collectPtrs(mlir::ValueRange operands,
28  llvm::SmallVector<std::pair<Value, Value>> &values,
29  bool hostToDevice) {
30  for (auto operand : operands) {
31  Value varPtr = acc::getVarPtr(operand.getDefiningOp());
32  Value accPtr = acc::getAccPtr(operand.getDefiningOp());
33  if (varPtr && accPtr) {
34  if (hostToDevice)
35  values.push_back({varPtr, accPtr});
36  else
37  values.push_back({accPtr, varPtr});
38  }
39  }
40 }
41 
42 template <typename Op>
43 static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
45 
46  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
47  collectPtrs(op.getReductionOperands(), values, hostToDevice);
48  collectPtrs(op.getPrivateOperands(), values, hostToDevice);
49  } else {
50  collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51  if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
52  collectPtrs(op.getReductionOperands(), values, hostToDevice);
53  collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
54  collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
55  }
56  }
57 
58  for (auto p : values)
59  replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
60 }
61 
62 struct LegalizeDataInRegion
63  : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
64 
65  void runOnOperation() override {
66  func::FuncOp funcOp = getOperation();
67  bool replaceHostVsDevice = this->hostToDevice.getValue();
68 
69  funcOp.walk([&](Operation *op) {
70  if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
71  return;
72 
73  if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
74  collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
75  } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
76  collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
77  } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
78  collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79  } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80  collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
81  }
82  });
83  }
84 };
85 
86 } // end anonymous namespace
87 
88 std::unique_ptr<OperationPass<func::FuncOp>>
90  return std::make_unique<LegalizeDataInRegion>();
91 }
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2841
std::unique_ptr< OperationPass< func::FuncOp > > createLegalizeDataInRegion()
Create a pass to replace ssa values in region with device/host values.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2851
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:28