61#define DEBUG_TYPE "acc-implicit-routine"
65#define GEN_PASS_DEF_ACCIMPLICITROUTINE
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
74class ACCImplicitRoutine
// Running index used to build names for generated routines: it is incremented
// once per acc::RoutineOp visited by initRoutineCounter, and post-incremented
// each time a new routine name is formed from accRoutinePrefix.
77 unsigned routineCounter = 0;
// Textual prefix that is concatenated with the current routineCounter value
// to form the name of a newly created routine.
78 static constexpr llvm::StringRef accRoutinePrefix =
"acc_routine_";
81 void initRoutineCounter(ModuleOp module) {
82 module.walk([&](acc::RoutineOp routineOp) { routineCounter++; });
88 bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
89 acc::DeviceType deviceType) {
91 if (!op.getBindIdName() && !op.getBindStrName())
93 return op.getBindNameValue().has_value() ||
94 op.getBindNameValue(deviceType).has_value();
99 FunctionOpInterface &callee) {
100 std::string routineName =
101 (accRoutinePrefix + std::to_string(routineCounter++)).str();
102 auto routineOp = acc::RoutineOp::create(
106 mlir::SymbolRefAttr::get(builder.
getContext(),
123 "function is already associated with a routine");
127 mlir::acc::RoutineInfoAttr::get(
129 {mlir::SymbolRefAttr::get(builder.getContext(),
130 builder.getStringAttr(routineName))}));
139 op->
walk([&](CallOpInterface callOp) {
140 if (!callOp.getCallableForCallee())
143 auto calleeSymbolRef =
144 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
147 if (!calleeSymbolRef)
150 auto callee = symTab.
lookup<FunctionOpInterface>(
151 calleeSymbolRef.getLeafReference().str());
155 assert(callee &&
"callee function must be found in symbol table");
159 createRoutineOp(builder, callee.getLoc(), callee);
164 void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
167 acc::DeviceType targetDeviceType) {
172 if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
175 SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
176 std::queue<acc::RoutineOp> routineQueue;
177 routineQueue.push(routineOp);
178 while (!routineQueue.empty()) {
179 auto currentRoutine = routineQueue.front();
182 currentRoutine.getFuncName().getLeafReference());
183 func.walk([&](CallOpInterface callOp) {
184 if (!callOp.getCallableForCallee())
187 auto calleeSymbolRef =
188 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
191 if (!calleeSymbolRef)
194 auto callee = symTab.
lookup<FunctionOpInterface>(
195 calleeSymbolRef.getLeafReference().str());
198 assert(callee &&
"callee function must be found in symbol table");
202 auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee);
203 routineQueue.push(newRoutineOp);
209 using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
211 void runOnOperation()
override {
212 auto module = getOperation();
215 initRoutineCounter(module);
220 module.walk([&](Operation *op) {
221 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
222 implicitRoutineForCallsInComputeRegions(op, symTab, builder,
227 acc::DeviceType targetDeviceType = deviceType;
230 module.walk([&](acc::RoutineOp routineOp) {
231 implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
static constexpr StringLiteral getRoutineInfoAttrName()
Include the generated interface declarations.