MLIR 23.0.0git
MarkDeclareTarget.cpp
Go to the documentation of this file.
1//===- MarkDeclareTarget.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Mark functions called from explicit target code as implicitly declare target.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/IR/Operation.h"
15#include "mlir/IR/SymbolTable.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22namespace mlir {
23namespace omp {
24
25#define GEN_PASS_DEF_MARKDECLARETARGETPASS
26#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
27
28} // namespace omp
29} // namespace mlir
30
31using namespace mlir;
32namespace {
33
34class MarkDeclareTargetPass
35 : public omp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
36
37 struct ParentInfo {
38 omp::DeclareTargetDeviceType devTy;
39 omp::DeclareTargetCaptureClause capClause;
40 bool automap;
41 };
42
43 void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
44 llvm::SmallPtrSet<Operation *, 16> visited) {
45 Operation *symOp = getOperation().lookupSymbol(symRef);
46 if (!symOp)
47 return;
48 auto current = llvm::dyn_cast<omp::DeclareTargetInterface>(symOp);
49 if (!current)
50 return;
51
52 if (current.isDeclareTarget()) {
53 auto currentDt = current.getDeclareTargetDeviceType();
54
55 // Found the same function twice, with different device_types,
56 // mark as Any as it belongs to both
57 if (currentDt != parentInfo.devTy &&
58 currentDt != omp::DeclareTargetDeviceType::any) {
59 current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
60 current.getDeclareTargetCaptureClause(),
61 current.getDeclareTargetAutomap());
62 }
63 } else {
64 current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
65 parentInfo.automap);
66 }
67
68 markNestedFuncs(parentInfo, symOp, visited);
69 }
70
71 void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
72 ParentInfo parentInfo,
73 llvm::SmallPtrSet<Operation *, 16> visited) {
74 if (!symRefs)
75 return;
76
77 for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
78 if (auto declareReductionOp =
79 getOperation().lookupSymbol<omp::DeclareReductionOp>(symRef)) {
80 markNestedFuncs(parentInfo, declareReductionOp, visited);
81 }
82 }
83 }
84
85 void processReductionClauses(Operation *op, ParentInfo parentInfo,
86 llvm::SmallPtrSet<Operation *, 16> visited) {
87 llvm::TypeSwitch<Operation &>(*op)
88 .Case([&](omp::LoopOp op) {
89 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
90 })
91 .Case([&](omp::ParallelOp op) {
92 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
93 })
94 .Case([&](omp::SectionsOp op) {
95 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
96 })
97 .Case([&](omp::SimdOp op) {
98 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
99 })
100 .Case([&](omp::TargetOp op) {
101 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
102 })
103 .Case([&](omp::TaskgroupOp op) {
104 processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
105 })
106 .Case([&](omp::TaskloopContextOp op) {
107 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
108 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
109 })
110 .Case([&](omp::TaskOp op) {
111 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
112 })
113 .Case([&](omp::TeamsOp op) {
114 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
115 })
116 .Case([&](omp::WsloopOp op) {
117 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
118 })
119 .Default([](Operation &) {});
120 }
121
122 void markNestedFuncs(ParentInfo parentInfo, Operation *currOp,
123 llvm::SmallPtrSet<Operation *, 16> visited) {
124 if (visited.contains(currOp))
125 return;
126 visited.insert(currOp);
127
128 currOp->walk([&, this](Operation *op) {
129 if (auto callOp = llvm::dyn_cast<CallOpInterface>(op)) {
130 if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
131 callOp.getCallableForCallee())) {
132 processSymbolRef(symRef, parentInfo, visited);
133 }
134 }
135 processReductionClauses(op, parentInfo, visited);
136 });
137 }
138
139 // This pass executes on mlir::ModuleOp's marking functions contained within
140 // as implicitly declare target if they are called from within an explicitly
141 // marked declare target function or a target region (TargetOp)
142 void runOnOperation() override {
143 for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
144 auto declareTargetOp =
145 llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
146 if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
147 continue;
148 llvm::SmallPtrSet<Operation *, 16> visited;
149 ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
150 declareTargetOp.getDeclareTargetCaptureClause(),
151 declareTargetOp.getDeclareTargetAutomap()};
152 markNestedFuncs(parentInfo, funcOp, visited);
153 }
154
155 // TODO: Extend to work with reverse-offloading, this shouldn't
156 // require too much effort, just need to check the device clause
157 // when it's lowering has been implemented and change the
158 // DeclareTargetDeviceType argument from nohost to host depending on
159 // the contents of the device clause
160 getOperation()->walk([&](omp::TargetOp tarOp) {
161 llvm::SmallPtrSet<Operation *, 16> visited;
162 ParentInfo parentInfo = {
163 /*devTy=*/omp::DeclareTargetDeviceType::nohost,
164 /*capClause=*/omp::DeclareTargetCaptureClause::to,
165 /*automap=*/false,
166 };
167 markNestedFuncs(parentInfo, tarOp, visited);
168 });
169 }
170};
171} // namespace
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
Include the generated interface declarations.