18 #define GEN_PASS_DEF_LEGALIZEDATAINREGION
19 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
30 for (
auto operand : operands) {
33 if (varPtr && accPtr) {
35 values.push_back({varPtr, accPtr});
37 values.push_back({accPtr, varPtr});
42 template <
typename Op>
43 static void collectAndReplaceInRegion(
Op &op,
bool hostToDevice) {
46 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
47 collectPtrs(op.getReductionOperands(), values, hostToDevice);
48 collectPtrs(op.getPrivateOperands(), values, hostToDevice);
50 collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51 if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
52 collectPtrs(op.getReductionOperands(), values, hostToDevice);
53 collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
54 collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
62 struct LegalizeDataInRegion
63 :
public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
65 void runOnOperation()
override {
66 func::FuncOp funcOp = getOperation();
67 bool replaceHostVsDevice = this->hostToDevice.getValue();
70 if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
73 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
74 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
75 }
else if (
auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
76 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
77 }
else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
78 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79 }
else if (
auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
88 std::unique_ptr<OperationPass<func::FuncOp>>
90 return std::make_unique<LegalizeDataInRegion>();
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
std::unique_ptr< OperationPass< func::FuncOp > > createLegalizeDataInRegion()
Create a pass to replace ssa values in region with device/host values.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.