40#include "llvm/Support/Debug.h"
44#define GEN_PASS_DEF_ACCBINDROUTINE
45#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "acc-bind-routine"
56static RoutineOp getFirstAccRoutineOp(FunctionOpInterface funcOp,
59 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
61 return symTab.
lookup<RoutineOp>(attr.getRoutine().getLeafReference());
65 assert(routineInfo &&
"expected acc.routine_info for acc routine function");
66 auto accRoutines = routineInfo.getAccRoutines();
67 assert(!accRoutines.empty() &&
"expected at least one acc routine");
68 return symTab.
lookup<RoutineOp>(accRoutines[0].getLeafReference());
71static bool isACCRoutineBindDefaultOrDeviceType(RoutineOp op,
72 DeviceType deviceType) {
73 if (!op.getBindIdName() && !op.getBindStrName())
75 return op.getBindNameValue().has_value() ||
76 op.getBindNameValue(deviceType).has_value();
79class ACCBindRoutine :
public acc::impl::ACCBindRoutineBase<ACCBindRoutine> {
81 using acc::impl::ACCBindRoutineBase<ACCBindRoutine>::ACCBindRoutineBase;
83 void runOnOperation()
override {
84 func::FuncOp
func = getOperation();
85 ModuleOp module =
func->getParentOfType<ModuleOp>();
91 getCachedParentAnalysis<OpenACCSupport>(
func->getParentOp());
93 cachedAnalysis ? cachedAnalysis->get() : getAnalysis<OpenACCSupport>();
97 func.walk([&](acc::OffloadRegionOpInterface offload) {
98 Region ®ion = offload.getOffloadRegion();
99 region.
walk([&](CallOpInterface callOp) {
100 if (!callOp.getCallableForCallee())
102 SymbolRefAttr calleeSymbolRef =
103 dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
104 if (!calleeSymbolRef)
107 FunctionOpInterface callee = symTab.
lookup<FunctionOpInterface>(
108 calleeSymbolRef.getLeafReference());
115 if (
auto routineInfo = callee->getAttrOfType<RoutineInfoAttr>(
117 if (routineInfo.getAccRoutines().size() > 1) {
119 "multiple `acc routine`s");
125 RoutineOp routine = getFirstAccRoutineOp(callee, symTab);
126 if (!isACCRoutineBindDefaultOrDeviceType(routine, this->deviceType))
129 auto bindNameOpt = routine.getBindNameValue(this->deviceType);
131 bindNameOpt = routine.getBindNameValue();
135 SymbolRefAttr calleeRef;
136 if (
auto *symRef = std::get_if<SymbolRefAttr>(&*bindNameOpt)) {
141 std::get<StringAttr>(*bindNameOpt).getValue());
143 callOp.setCalleeFromCallable(calleeRef);
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isAccRoutine(mlir::Operation *op)
Used to check whether the current operation is marked with acc routine.
static constexpr StringLiteral getSpecializedRoutineAttrName()
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
static constexpr StringLiteral getRoutineInfoAttrName()
Include the generated interface declarations.