MLIR 23.0.0git
MarkDeclareTarget.cpp
Go to the documentation of this file.
1//===- MarkDeclareTarget.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Mark functions called from explicit target code as implicitly declare target.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Operation.h"
16#include "mlir/IR/SymbolTable.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22namespace mlir {
23namespace omp {
24
25#define GEN_PASS_DEF_MARKDECLARETARGETPASS
26#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
27
28} // namespace omp
29} // namespace mlir
30
31using namespace mlir;
32namespace {
33
34class MarkDeclareTargetPass
35 : public omp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
36
37 struct ParentInfo {
38 omp::DeclareTargetDeviceType devTy;
39 omp::DeclareTargetCaptureClause capClause;
40 bool automap;
41 };
42
43 void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
44 llvm::SmallPtrSet<Operation *, 16> visited) {
45 if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
46 auto current =
47 llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
48
49 if (current.isDeclareTarget()) {
50 auto currentDt = current.getDeclareTargetDeviceType();
51
52 // Found the same function twice, with different device_types,
53 // mark as Any as it belongs to both
54 if (currentDt != parentInfo.devTy &&
55 currentDt != omp::DeclareTargetDeviceType::any) {
56 current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
57 current.getDeclareTargetCaptureClause(),
58 current.getDeclareTargetAutomap());
59 }
60 } else {
61 current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
62 parentInfo.automap);
63 }
64
65 markNestedFuncs(parentInfo, currFOp, visited);
66 }
67 }
68
69 void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
70 ParentInfo parentInfo,
71 llvm::SmallPtrSet<Operation *, 16> visited) {
72 if (!symRefs)
73 return;
74
75 for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
76 if (auto declareReductionOp =
77 getOperation().lookupSymbol<omp::DeclareReductionOp>(symRef)) {
78 markNestedFuncs(parentInfo, declareReductionOp, visited);
79 }
80 }
81 }
82
83 void processReductionClauses(Operation *op, ParentInfo parentInfo,
84 llvm::SmallPtrSet<Operation *, 16> visited) {
85 llvm::TypeSwitch<Operation &>(*op)
86 .Case([&](omp::LoopOp op) {
87 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
88 })
89 .Case([&](omp::ParallelOp op) {
90 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
91 })
92 .Case([&](omp::SectionsOp op) {
93 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
94 })
95 .Case([&](omp::SimdOp op) {
96 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
97 })
98 .Case([&](omp::TargetOp op) {
99 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
100 })
101 .Case([&](omp::TaskgroupOp op) {
102 processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
103 })
104 .Case([&](omp::TaskloopOp op) {
105 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
106 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
107 })
108 .Case([&](omp::TaskOp op) {
109 processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
110 })
111 .Case([&](omp::TeamsOp op) {
112 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
113 })
114 .Case([&](omp::WsloopOp op) {
115 processReductionRefs(op.getReductionSyms(), parentInfo, visited);
116 })
117 .Default([](Operation &) {});
118 }
119
120 void markNestedFuncs(ParentInfo parentInfo, Operation *currOp,
121 llvm::SmallPtrSet<Operation *, 16> visited) {
122 if (visited.contains(currOp))
123 return;
124 visited.insert(currOp);
125
126 currOp->walk([&, this](Operation *op) {
127 if (auto callOp = llvm::dyn_cast<CallOpInterface>(op)) {
128 if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
129 callOp.getCallableForCallee())) {
130 processSymbolRef(symRef, parentInfo, visited);
131 }
132 }
133 processReductionClauses(op, parentInfo, visited);
134 });
135 }
136
137 // This pass executes on mlir::ModuleOp's marking functions contained within
138 // as implicitly declare target if they are called from within an explicitly
139 // marked declare target function or a target region (TargetOp)
140 void runOnOperation() override {
141 for (auto functionOp : getOperation().getOps<func::FuncOp>()) {
142 auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
143 functionOp.getOperation());
144 if (declareTargetOp.isDeclareTarget()) {
145 llvm::SmallPtrSet<Operation *, 16> visited;
146 ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
147 declareTargetOp.getDeclareTargetCaptureClause(),
148 declareTargetOp.getDeclareTargetAutomap()};
149 markNestedFuncs(parentInfo, functionOp, visited);
150 }
151 }
152
153 // TODO: Extend to work with reverse-offloading, this shouldn't
154 // require too much effort, just need to check the device clause
155 // when it's lowering has been implemented and change the
156 // DeclareTargetDeviceType argument from nohost to host depending on
157 // the contents of the device clause
158 getOperation()->walk([&](omp::TargetOp tarOp) {
159 llvm::SmallPtrSet<Operation *, 16> visited;
160 ParentInfo parentInfo = {
161 /*devTy=*/omp::DeclareTargetDeviceType::nohost,
162 /*capClause=*/omp::DeclareTargetCaptureClause::to,
163 /*automap=*/false,
164 };
165 markNestedFuncs(parentInfo, tarOp, visited);
166 });
167 }
168};
169} // namespace
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
Include the generated interface declarations.