19#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/ADT/TypeSwitch.h"
25#define GEN_PASS_DEF_MARKDECLARETARGETPASS
26#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
// Pass that propagates OpenMP declare-target information to called functions.
// Roots are functions already marked declare-target and omp.target regions
// (see runOnOperation). Each root's settings are pushed onto every function it
// reaches through calls or through reduction declarations.
34class MarkDeclareTargetPass
35 :
public omp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
// ParentInfo::devTy - device type propagated from the root to its callees.
38 omp::DeclareTargetDeviceType devTy;
// ParentInfo::capClause - capture clause propagated from the root to its
// callees.
39 omp::DeclareTargetCaptureClause capClause;
// Propagate the parent's declare-target information to the operation named
// by `symRef`, then recurse into that operation via markNestedFuncs.
//
// The symbol is resolved with getOperation().lookupSymbol and then cast to
// omp::DeclareTargetInterface.
// - If the target is already declare-target and its device type differs from
//   the parent's (and is not already `any`), it is widened to `any`. Its
//   existing capture clause and automap setting are kept.
// - Otherwise, the parent's device type and capture clause are applied.
//
// NOTE(review): `visited` is taken by value, so each call gets its own copy
// of the set. Consider passing it by reference
// (llvm::SmallPtrSetImpl<Operation *> &) to avoid copying on every call.
43 void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
44 llvm::SmallPtrSet<Operation *, 16> visited) {
45 Operation *symOp = getOperation().lookupSymbol(symRef);
48 auto current = llvm::dyn_cast<omp::DeclareTargetInterface>(symOp);
// The callee already carries declare-target information of its own.
52 if (current.isDeclareTarget()) {
53 auto currentDt = current.getDeclareTargetDeviceType();
// The callee is needed for a different device type than the parent, so
// widen it to `any`, keeping its capture clause and automap flag.
57 if (currentDt != parentInfo.devTy &&
58 currentDt != omp::DeclareTargetDeviceType::any) {
59 current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
60 current.getDeclareTargetCaptureClause(),
61 current.getDeclareTargetAutomap());
// Presumably reached when the callee is not yet declare-target: adopt the
// parent's settings (TODO confirm branch structure).
64 current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
// Continue the traversal into the callee with the same parent information.
68 markNestedFuncs(parentInfo, symOp, visited);
// Follow a list of reduction symbol references.
// Each symbol that resolves (via getOperation().lookupSymbol) to an
// omp::DeclareReductionOp is traversed with markNestedFuncs. That way, calls
// made from inside the reduction declaration receive the parent's
// declare-target information. Symbols that do not resolve to a
// DeclareReductionOp are skipped.
//
// NOTE(review): `symRefs` is a std::optional and is dereferenced below
// (`symRefs->`); confirm an empty optional is rejected before the loop.
// `visited` is taken by value, so it is copied on every call.
71 void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
72 ParentInfo parentInfo,
73 llvm::SmallPtrSet<Operation *, 16> visited) {
77 for (
auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
78 if (
auto declareReductionOp =
79 getOperation().lookupSymbol<omp::DeclareReductionOp>(symRef)) {
80 markNestedFuncs(parentInfo, declareReductionOp, visited);
// Dispatch on the concrete OpenMP operation type and forward each of its
// reduction-symbol lists (reduction, in_reduction or task_reduction) to
// processReductionRefs. Any other operation type is ignored.
//
// NOTE(review): `visited` is taken by value, so it is copied on every call.
85 void processReductionClauses(Operation *op, ParentInfo parentInfo,
86 llvm::SmallPtrSet<Operation *, 16> visited) {
87 llvm::TypeSwitch<Operation &>(*op)
// loop, parallel, sections, simd: `reduction` symbols.
88 .Case([&](omp::LoopOp op) {
89 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
91 .Case([&](omp::ParallelOp op) {
92 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
94 .Case([&](omp::SectionsOp op) {
95 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
97 .Case([&](omp::SimdOp op) {
98 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
// target: only `in_reduction` symbols are followed.
100 .Case([&](omp::TargetOp op) {
101 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
// taskgroup: `task_reduction` symbols.
103 .Case([&](omp::TaskgroupOp op) {
104 processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
// taskloop context: both `reduction` and `in_reduction` symbols.
106 .Case([&](omp::TaskloopContextOp op) {
107 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
108 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
// task: only `in_reduction` symbols are followed.
110 .Case([&](omp::TaskOp op) {
111 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
// teams, wsloop: `reduction` symbols.
113 .Case([&](omp::TeamsOp op) {
114 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
116 .Case([&](omp::WsloopOp op) {
117 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
// Any other operation is ignored.
119 .Default([](Operation &) {});
// Walk `currOp` and every operation nested in it, propagating `parentInfo`.
// - For each call whose callee is a symbol reference, processSymbolRef marks
//   the callee and recurses back into this function.
// - Every walked operation is also passed to processReductionClauses, so
//   functions reached through reduction declarations are marked as well.
//
// `visited` guards against re-entering an operation already in the set,
// which terminates the recursion on recursive call chains.
//
// NOTE(review): `visited` is passed by value. Insertions made in a recursive
// call are therefore not seen by sibling calls, and a callee reachable along
// several call paths is walked once per path (plus a set copy per call).
// `parentInfo` is passed through unchanged during one traversal. Taking the
// set by reference (llvm::SmallPtrSetImpl<Operation *> &) in all helpers
// would give the same result without the copies and repeated walks.
122 void markNestedFuncs(ParentInfo parentInfo, Operation *currOp,
123 llvm::SmallPtrSet<Operation *, 16> visited) {
// Stop if this operation was already traversed; otherwise record it.
124 if (visited.contains(currOp))
126 visited.insert(currOp);
128 currOp->
walk([&,
this](Operation *op) {
// Direct calls: resolve the callee symbol and propagate to it.
129 if (
auto callOp = llvm::dyn_cast<CallOpInterface>(op)) {
130 if (
auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
131 callOp.getCallableForCallee())) {
132 processSymbolRef(symRef, parentInfo, visited);
// Reduction clauses may reference declarations that themselves make calls.
135 processReductionClauses(op, parentInfo, visited);
// Pass entry point. Propagation starts from two kinds of roots:
//  1. Each top-level FunctionOpInterface op (from getOps on the root
//     operation) that is already declare-target. It propagates its own
//     device type, capture clause and automap setting.
//  2. Each omp::TargetOp anywhere under the root operation. It propagates a
//     `nohost` device type with a `to` capture clause.
// Each root gets its own fresh `visited` set, so root traversals are
// independent of one another.
142 void runOnOperation()
override {
143 for (
auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
144 auto declareTargetOp =
145 llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
// Functions that are not declare-target do not act as roots.
146 if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
148 llvm::SmallPtrSet<Operation *, 16> visited;
// Propagate this function's own declare-target settings to its callees.
149 ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
150 declareTargetOp.getDeclareTargetCaptureClause(),
151 declareTargetOp.getDeclareTargetAutomap()};
152 markNestedFuncs(parentInfo, funcOp, visited);
// Functions called from inside target regions are marked nohost / to.
160 getOperation()->walk([&](omp::TargetOp tarOp) {
161 llvm::SmallPtrSet<Operation *, 16> visited;
162 ParentInfo parentInfo = {
163 omp::DeclareTargetDeviceType::nohost,
164 omp::DeclareTargetCaptureClause::to,
167 markNestedFuncs(parentInfo, tarOp, visited);
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Include the generated interface declarations.