MLIR 22.0.0git
ACCImplicitRoutine.cpp
Go to the documentation of this file.
1//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the implicit rules described in OpenACC specification
10// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1).
11//
12// "If no explicit routine directive applies to a procedure whose definition
13// appears in the program unit being compiled, then the implementation applies
14// an implicit routine directive to that procedure if any of the following
15// conditions holds:
16// - The procedure is called or its address is accessed in a compute region."
17//
18// The specification further states:
19// "When the implementation applies an implicit routine directive to a
20// procedure, it must recursively apply implicit routine directives to other
21// procedures for which the above rules specify relevant dependencies. Such
22// dependencies can form a cycle, so the implementation must take care to avoid
23// infinite recursion."
24//
25// This pass implements these requirements by:
26// 1. Walking through all OpenACC compute constructs and functions already
27// marked with `acc routine` in the module and identifying function calls
28// within these regions.
29// 2. Creating implicit `acc.routine` operations for functions that don't
30// already have routine declarations.
31// 3. Recursively walking through all existing `acc routine` and creating
32// implicit routine operations for function calls within these routines,
33// while avoiding infinite recursion through proper tracking.
34//
35// Requirements:
36// -------------
37// To use this pass in a pipeline, the following requirements must be met:
38//
39// 1. Operation Interface Implementation: Operations that define functions
40// or call functions should implement `mlir::FunctionOpInterface` and
41// `mlir::CallOpInterface` respectively.
42//
43// 2. Analysis Registration (Optional): If custom behavior is needed for
44// determining if a symbol use is valid within GPU regions, the dialect
45// should pre-register the `acc::OpenACCSupport` analysis.
46//===----------------------------------------------------------------------===//
47
49
52#include "mlir/IR/Builders.h"
54#include "mlir/IR/BuiltinOps.h"
55#include "mlir/IR/Operation.h"
56#include "mlir/IR/Value.h"
59#include <queue>
60
61#define DEBUG_TYPE "acc-implicit-routine"
62
63namespace mlir {
64namespace acc {
65#define GEN_PASS_DEF_ACCIMPLICITROUTINE
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
67} // namespace acc
68} // namespace mlir
69
70namespace {
71
72using namespace mlir;
73
74class ACCImplicitRoutine
75 : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> {
76private:
77 unsigned routineCounter = 0;
78 static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
79
80 // Count existing routine operations and update counter
81 void initRoutineCounter(ModuleOp module) {
82 module.walk([&](acc::RoutineOp routineOp) { routineCounter++; });
83 }
84
85 // Check if routine has a default bind clause or a device-type specific bind
86 // clause. Returns true if `acc routine` has a default bind clause or
87 // a device-type specific bind clause.
88 bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
89 acc::DeviceType deviceType) {
90 // Fast check to avoid device-type specific lookups.
91 if (!op.getBindIdName() && !op.getBindStrName())
92 return false;
93 return op.getBindNameValue().has_value() ||
94 op.getBindNameValue(deviceType).has_value();
95 }
96
97 // Generate a unique name for the routine and create the routine operation
98 acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc,
99 FunctionOpInterface &callee) {
100 std::string routineName =
101 (accRoutinePrefix + std::to_string(routineCounter++)).str();
102 auto routineOp = acc::RoutineOp::create(
103 builder, loc,
104 /* sym_name=*/builder.getStringAttr(routineName),
105 /* func_name=*/
106 mlir::SymbolRefAttr::get(builder.getContext(),
107 builder.getStringAttr(callee.getName())),
108 /* bindIdName=*/nullptr,
109 /* bindStrName=*/nullptr,
110 /* bindIdNameDeviceType=*/nullptr,
111 /* bindStrNameDeviceType=*/nullptr,
112 /* worker=*/nullptr,
113 /* vector=*/nullptr,
114 /* seq=*/nullptr,
115 /* nohost=*/nullptr,
116 /* implicit=*/builder.getUnitAttr(),
117 /* gang=*/nullptr,
118 /* gangDim=*/nullptr,
119 /* gangDimDeviceType=*/nullptr);
120
121 // Assert that the callee does not already have routine info attribute
122 assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) &&
123 "function is already associated with a routine");
124
125 callee->setAttr(
127 mlir::acc::RoutineInfoAttr::get(
128 builder.getContext(),
129 {mlir::SymbolRefAttr::get(builder.getContext(),
130 builder.getStringAttr(routineName))}));
131 return routineOp;
133
134 // Used to walk through a compute region looking for function calls.
135 void
136 implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab,
138 acc::OpenACCSupport &accSupport) {
139 op->walk([&](CallOpInterface callOp) {
140 if (!callOp.getCallableForCallee())
141 return;
142
143 auto calleeSymbolRef =
144 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
145 // When call is done through ssa value, the callee is not a symbol.
146 // Skip it because we don't know the call target.
147 if (!calleeSymbolRef)
148 return;
150 auto callee = symTab.lookup<FunctionOpInterface>(
151 calleeSymbolRef.getLeafReference().str());
152 // If the callee does not exist or is already a valid symbol for GPU
153 // regions, skip it
154
155 assert(callee && "callee function must be found in symbol table");
156 if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
157 return;
158 builder.setInsertionPoint(callee);
159 createRoutineOp(builder, callee.getLoc(), callee);
160 });
161 }
162
163 // Recursively handle calls within a routine operation
164 void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
165 mlir::OpBuilder &builder,
166 acc::OpenACCSupport &accSupport,
167 acc::DeviceType targetDeviceType) {
168 // When bind clause is used, it means that the target is different than the
169 // function to which the `acc routine` is used with. Skip this case to
170 // avoid implicitly recursively marking calls that would not end up on
171 // device.
172 if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
173 return;
174
175 SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
176 std::queue<acc::RoutineOp> routineQueue;
177 routineQueue.push(routineOp);
178 while (!routineQueue.empty()) {
179 auto currentRoutine = routineQueue.front();
180 routineQueue.pop();
181 auto func = symTab.lookup<FunctionOpInterface>(
182 currentRoutine.getFuncName().getLeafReference());
183 func.walk([&](CallOpInterface callOp) {
184 if (!callOp.getCallableForCallee())
185 return;
186
187 auto calleeSymbolRef =
188 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
189 // When call is done through ssa value, the callee is not a symbol.
190 // Skip it because we don't know the call target.
191 if (!calleeSymbolRef)
192 return;
193
194 auto callee = symTab.lookup<FunctionOpInterface>(
195 calleeSymbolRef.getLeafReference().str());
196 // If the callee does not exist or is already a valid symbol for GPU
197 // regions, skip it
198 assert(callee && "callee function must be found in symbol table");
199 if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
200 return;
201 builder.setInsertionPoint(callee);
202 auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee);
203 routineQueue.push(newRoutineOp);
204 });
205 }
206 }
207
208public:
209 using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
210
211 void runOnOperation() override {
212 auto module = getOperation();
213 mlir::OpBuilder builder(module.getContext());
214 SymbolTable symTab(module);
215 initRoutineCounter(module);
216
217 acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
218
219 // Handle compute regions
220 module.walk([&](Operation *op) {
221 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
222 implicitRoutineForCallsInComputeRegions(op, symTab, builder,
223 accSupport);
224 });
225
226 // Use the device type option from the pass options.
227 acc::DeviceType targetDeviceType = deviceType;
228
229 // Handle existing routines
230 module.walk([&](acc::RoutineOp routineOp) {
231 implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
232 targetDeviceType);
233 });
234 }
235};
236
237} // namespace
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
static constexpr StringLiteral getRoutineInfoAttrName()
Definition OpenACC.h:176
Include the generated interface declarations.