19#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/ADT/TypeSwitch.h"
25#define GEN_PASS_DEF_MARKDECLARETARGETPASS
26#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
// Pass class deriving from the generated
// omp::impl::MarkDeclareTargetPassBase, parameterised on itself.
// NOTE(review): this listing has line numbers fused to each line and skips
// some lines, so parts of the class (including its closing brace) are not
// visible here.
34class MarkDeclareTargetPass
35 :
public omp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
// Device-type and capture-clause fields. Presumably members of the ParentInfo
// aggregate whose `.devTy` / `.capClause` are read by processSymbolRef; the
// enclosing declaration line is not shown -- TODO confirm.
38 omp::DeclareTargetDeviceType devTy;
39 omp::DeclareTargetCaptureClause capClause;
// Looks up `symRef` as a func::FuncOp in the operation this pass runs on and,
// if found, updates that function's declare-target marking from `parentInfo`,
// then continues the traversal inside it via markNestedFuncs.
//
// symRef     - symbol reference to resolve.
// parentInfo - device type / capture clause of the enclosing context.
// visited    - operations already traversed; received by value, so this call
//              operates on its own copy of the set.
43 void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
44 llvm::SmallPtrSet<Operation *, 16> visited) {
// Symbols that do not resolve to a func::FuncOp are ignored.
45 if (
auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
// View the function through the declare-target interface. The result is
// presumably bound to `current` (the declaration line is not shown).
// NOTE(review): no null check on the dyn_cast result is visible -- confirm
// func::FuncOp always implements omp::DeclareTargetInterface.
47 llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
49 if (current.isDeclareTarget()) {
50 auto currentDt = current.getDeclareTargetDeviceType();
// Already marked with a device type that differs from the parent's and is
// not yet `any`: widen it to `any`, keeping the function's own capture
// clause and automap setting.
54 if (currentDt != parentInfo.devTy &&
55 currentDt != omp::DeclareTargetDeviceType::any) {
56 current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
57 current.getDeclareTargetCaptureClause(),
58 current.getDeclareTargetAutomap());
// Apply the parent's device type and capture clause. Presumably this is the
// branch for functions not yet marked declare-target; the `else` line and
// the final argument are not shown -- TODO confirm.
61 current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
// Continue into the body of the resolved function.
65 markNestedFuncs(parentInfo, currFOp, visited);
// For every symbol reference in `symRefs` that resolves to an
// omp::DeclareReductionOp, traverses that reduction declaration with
// markNestedFuncs.
//
// symRefs    - optional array of reduction symbol references.
// parentInfo - forwarded unchanged to markNestedFuncs.
// visited    - received by value and forwarded to markNestedFuncs.
69 void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
70 ParentInfo parentInfo,
71 llvm::SmallPtrSet<Operation *, 16> visited) {
// NOTE(review): `symRefs` is dereferenced with `->` below; any guard for an
// empty optional would be on lines not shown in this listing -- confirm.
75 for (
auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
// References that do not name a DeclareReductionOp are skipped.
76 if (
auto declareReductionOp =
77 getOperation().lookupSymbol<omp::DeclareReductionOp>(symRef)) {
78 markNestedFuncs(parentInfo, declareReductionOp, visited);
// Dispatches on the concrete type of `op` and passes the reduction symbol
// arrays that op carries to processReductionRefs. Operations of any type not
// listed here are left untouched.
//
// op         - operation to inspect.
// parentInfo - forwarded unchanged to processReductionRefs.
// visited    - received by value and forwarded to processReductionRefs.
83 void processReductionClauses(Operation *op, ParentInfo parentInfo,
84 llvm::SmallPtrSet<Operation *, 16> visited) {
85 llvm::TypeSwitch<Operation &>(*op)
86 .Case([&](omp::LoopOp op) {
87 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
89 .Case([&](omp::ParallelOp op) {
90 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
92 .Case([&](omp::SectionsOp op) {
93 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
95 .Case([&](omp::SimdOp op) {
96 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
// TargetOp: the in-reduction symbols are the ones processed.
98 .Case([&](omp::TargetOp op) {
99 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
// TaskgroupOp: the task-reduction symbols are the ones processed.
101 .Case([&](omp::TaskgroupOp op) {
102 processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
// TaskloopOp: both the reduction and the in-reduction symbols are processed.
104 .Case([&](omp::TaskloopOp op) {
105 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
106 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
// TaskOp: the in-reduction symbols are the ones processed.
108 .Case([&](omp::TaskOp op) {
109 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
111 .Case([&](omp::TeamsOp op) {
112 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
114 .Case([&](omp::WsloopOp op) {
115 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
// Every other operation type: nothing to do.
117 .Default([](Operation &) {});
// Walks `currOp` and, for each operation visited by the walk, (1) resolves
// symbol callees of call-like operations through processSymbolRef and
// (2) processes reduction clauses through processReductionClauses.
//
// parentInfo - forwarded unchanged to both helpers.
// currOp     - root operation of the walk.
// visited    - operations already traversed. NOTE(review): taken by value, so
//              the insertion below is only seen by this call and the calls it
//              makes, not by the caller -- confirm that is intended.
120 void markNestedFuncs(ParentInfo parentInfo, Operation *currOp,
121 llvm::SmallPtrSet<Operation *, 16> visited) {
// Skip operations already in the set (presumably an early return; the
// statement line is not shown), otherwise record this one.
122 if (visited.contains(currOp))
124 visited.insert(currOp);
126 currOp->
walk([&,
this](Operation *op) {
// Only calls whose callee is a symbol reference are followed.
127 if (
auto callOp = llvm::dyn_cast<CallOpInterface>(op)) {
128 if (
auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
129 callOp.getCallableForCallee())) {
130 processSymbolRef(symRef, parentInfo, visited);
// Reduction clauses are checked for the walked operation `op`.
133 processReductionClauses(op, parentInfo, visited);
// Pass entry point. Two phases are visible:
//  1. For each func::FuncOp in the operation being processed that is already
//     declare-target, traverse it with a ParentInfo built from that function's
//     own device type, capture clause and automap setting.
//  2. For each omp::TargetOp found by walking the operation, traverse it with
//     a ParentInfo of device type `nohost` and capture clause `to`.
// Each traversal starts from a fresh, empty `visited` set.
140 void runOnOperation()
override {
141 for (
auto functionOp : getOperation().getOps<func::FuncOp>()) {
// NOTE(review): the dyn_cast result is used without a visible null check --
// confirm func::FuncOp always implements omp::DeclareTargetInterface.
142 auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
143 functionOp.getOperation());
144 if (declareTargetOp.isDeclareTarget()) {
145 llvm::SmallPtrSet<Operation *, 16> visited;
146 ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
147 declareTargetOp.getDeclareTargetCaptureClause(),
148 declareTargetOp.getDeclareTargetAutomap()};
149 markNestedFuncs(parentInfo, functionOp, visited);
// Second phase: target regions are traversed with a fixed ParentInfo.
158 getOperation()->walk([&](omp::TargetOp tarOp) {
159 llvm::SmallPtrSet<Operation *, 16> visited;
// Any further initialiser values are on lines not shown in this listing.
160 ParentInfo parentInfo = {
161 omp::DeclareTargetDeviceType::nohost,
162 omp::DeclareTargetCaptureClause::to,
165 markNestedFuncs(parentInfo, tarOp, visited);
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Include the generated interface declarations.