MLIR 23.0.0git
ACCBindRoutine.cpp
Go to the documentation of this file.
1//===- ACCBindRoutine.cpp - OpenACC bind routine transform ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The OpenACC `routine` directive may specify a `bind(name)` clause to
10// associate the routine with a different symbol for device code. This pass
11// finds calls inside offload regions that target such routines and rewrites the
12// callee to the bound symbol.
13//
14// Overview:
15// ---------
16// For each function, walk operations that implement OffloadRegionOpInterface.
17// For each call inside the offload region, if the callee is a function with
18// an acc routine that has bind(name), replace the call to use the bound
19// symbol.
20//
21// Requirements:
22// -------------
23// - OffloadRegionOpInterface: the pass walks operations implementing this
24// interface to discover offload regions (e.g. acc.compute_region) and
25// rewrites calls inside their getOffloadRegion().
26// - CallOpInterface with working setCalleeFromCallable: call operations
27// must implement CallOpInterface and setCalleeFromCallable so the pass
28// can rewrite the callee to the symbol without invalidating the call.
29//
30//===----------------------------------------------------------------------===//
31
33
38#include "mlir/IR/SymbolTable.h"
40#include "llvm/Support/Debug.h"
41
42namespace mlir {
43namespace acc {
44#define GEN_PASS_DEF_ACCBINDROUTINE
45#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
46} // namespace acc
47} // namespace mlir
48
49#define DEBUG_TYPE "acc-bind-routine"
50
51using namespace mlir;
52using namespace mlir::acc;
53
54namespace {
55
56static RoutineOp getFirstAccRoutineOp(FunctionOpInterface funcOp,
57 const SymbolTable &symTab) {
58 if (isSpecializedAccRoutine(funcOp)) {
59 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
61 return symTab.lookup<RoutineOp>(attr.getRoutine().getLeafReference());
62 }
63 auto routineInfo =
64 funcOp->getAttrOfType<RoutineInfoAttr>(getRoutineInfoAttrName());
65 assert(routineInfo && "expected acc.routine_info for acc routine function");
66 auto accRoutines = routineInfo.getAccRoutines();
67 assert(!accRoutines.empty() && "expected at least one acc routine");
68 return symTab.lookup<RoutineOp>(accRoutines[0].getLeafReference());
69}
70
71static bool isACCRoutineBindDefaultOrDeviceType(RoutineOp op,
72 DeviceType deviceType) {
73 if (!op.getBindIdName() && !op.getBindStrName())
74 return false;
75 return op.getBindNameValue().has_value() ||
76 op.getBindNameValue(deviceType).has_value();
77}
78
79class ACCBindRoutine : public acc::impl::ACCBindRoutineBase<ACCBindRoutine> {
80public:
81 using acc::impl::ACCBindRoutineBase<ACCBindRoutine>::ACCBindRoutineBase;
82
83 void runOnOperation() override {
84 func::FuncOp func = getOperation();
85 ModuleOp module = func->getParentOfType<ModuleOp>();
86 if (!module)
87 return;
88
89 SymbolTable symTab(module);
90 auto cachedAnalysis =
91 getCachedParentAnalysis<OpenACCSupport>(func->getParentOp());
92 OpenACCSupport &accSupport =
93 cachedAnalysis ? cachedAnalysis->get() : getAnalysis<OpenACCSupport>();
94
95 bool failed = false;
96
97 func.walk([&](acc::OffloadRegionOpInterface offload) {
98 Region &region = offload.getOffloadRegion();
99 region.walk([&](CallOpInterface callOp) {
100 if (!callOp.getCallableForCallee())
101 return;
102 SymbolRefAttr calleeSymbolRef =
103 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
104 if (!calleeSymbolRef)
105 return;
106
107 FunctionOpInterface callee = symTab.lookup<FunctionOpInterface>(
108 calleeSymbolRef.getLeafReference());
109 if (!callee)
110 return;
111
112 if (!(isAccRoutine(callee) || isSpecializedAccRoutine(callee)))
113 return;
114
115 if (auto routineInfo = callee->getAttrOfType<RoutineInfoAttr>(
117 if (routineInfo.getAccRoutines().size() > 1) {
118 (void)accSupport.emitNYI(callOp.getLoc(),
119 "multiple `acc routine`s");
120 failed = true;
121 return;
122 }
123 }
124
125 RoutineOp routine = getFirstAccRoutineOp(callee, symTab);
126 if (!isACCRoutineBindDefaultOrDeviceType(routine, this->deviceType))
127 return;
128
129 auto bindNameOpt = routine.getBindNameValue(this->deviceType);
130 if (!bindNameOpt)
131 bindNameOpt = routine.getBindNameValue();
132 if (!bindNameOpt)
133 return;
134
135 SymbolRefAttr calleeRef;
136 if (auto *symRef = std::get_if<SymbolRefAttr>(&*bindNameOpt)) {
137 calleeRef = *symRef;
138 } else {
139 calleeRef = FlatSymbolRefAttr::get(
140 callOp.getContext(),
141 std::get<StringAttr>(*bindNameOpt).getValue());
142 }
143 callOp.setCalleeFromCallable(calleeRef);
144 });
145 });
146
147 if (failed)
148 signalPassFailure();
149 }
150};
151
152} // namespace
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isAccRoutine(mlir::Operation *op)
Used to check whether the current operation is marked with acc routine.
Definition OpenACC.h:192
static constexpr StringLiteral getSpecializedRoutineAttrName()
Definition OpenACC.h:186
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
Definition OpenACC.h:198
static constexpr StringLiteral getRoutineInfoAttrName()
Definition OpenACC.h:182
Include the generated interface declarations.