61#define DEBUG_TYPE "acc-implicit-routine"
65#define GEN_PASS_DEF_ACCIMPLICITROUTINE
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
74class ACCImplicitRoutine
77 unsigned routineCounter = 0;
78 static constexpr llvm::StringRef accRoutinePrefix =
"acc_routine_";
81 void initRoutineCounter(ModuleOp module) {
82 module.walk([&](acc::RoutineOp routineOp) { routineCounter++; });
88 bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
89 acc::DeviceType deviceType) {
91 if (!op.getBindIdName() && !op.getBindStrName())
93 return op.getBindNameValue().has_value() ||
94 op.getBindNameValue(deviceType).has_value();
99 FunctionOpInterface &callee) {
100 std::string routineName =
101 (accRoutinePrefix + std::to_string(routineCounter++)).str();
102 auto routineOp = acc::RoutineOp::create(
106 mlir::SymbolRefAttr::get(builder.
getContext(),
123 "function is already associated with a routine");
127 mlir::acc::RoutineInfoAttr::get(
129 {mlir::SymbolRefAttr::get(builder.getContext(),
130 builder.getStringAttr(routineName))}));
139 op->
walk([&](CallOpInterface callOp) {
140 if (!callOp.getCallableForCallee())
143 auto calleeSymbolRef =
144 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
147 if (!calleeSymbolRef)
150 auto callee = symTab.
lookup<FunctionOpInterface>(
151 calleeSymbolRef.getLeafReference().str());
155 assert(callee &&
"callee function must be found in symbol table");
159 createRoutineOp(builder, callee.getLoc(), callee);
164 void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
167 acc::DeviceType targetDeviceType) {
172 if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
175 SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
176 std::queue<acc::RoutineOp> routineQueue;
177 routineQueue.push(routineOp);
178 while (!routineQueue.empty()) {
179 auto currentRoutine = routineQueue.front();
181 auto func = symTab.
lookup<FunctionOpInterface>(
182 currentRoutine.getFuncName().getLeafReference());
183 func.walk([&](CallOpInterface callOp) {
184 if (!callOp.getCallableForCallee())
187 auto calleeSymbolRef =
188 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
191 if (!calleeSymbolRef)
194 auto callee = symTab.
lookup<FunctionOpInterface>(
195 calleeSymbolRef.getLeafReference().str());
198 assert(callee &&
"callee function must be found in symbol table");
202 auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee);
203 routineQueue.push(newRoutineOp);
209 using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
212 auto module = getOperation();
215 initRoutineCounter(module);
220 module.walk([&](Operation *op) {
221 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
222 implicitRoutineForCallsInComputeRegions(op, symTab, builder,
230 module.walk([&](acc::RoutineOp routineOp) {
231 implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
AnalysisT & getAnalysis()
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
::mlir::Pass::Option< mlir::acc::DeviceType > deviceType
static constexpr StringLiteral getRoutineInfoAttrName()
Include the generated interface declarations.