187#include "llvm/ADT/SmallVector.h"
188#include "llvm/ADT/TypeSwitch.h"
192#define GEN_PASS_DEF_ACCIMPLICITDECLARE
193#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
197#define DEBUG_TYPE "acc-implicit-declare"
203using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>;
208static bool isGlobalUseCandidateForHoisting(
Operation *globalOp,
210 SymbolRefAttr symbol,
217 bool isConstant =
false;
218 bool isFunction =
false;
220 if (
auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp))
221 isConstant = globalVarOp.isConstant();
223 if (isa<FunctionOpInterface>(globalOp))
230 return !isConstant && !isFunction;
234bool isValidForAccDeclare(
Operation *globalOp) {
236 return !isa<FunctionOpInterface>(globalOp);
243template <
typename RecipeOpT>
244static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) {
245 std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod);
248 if (!symbolUses.has_value() || symbolUses->empty())
252 auto begin = symbolUses->begin();
253 auto end = symbolUses->end();
254 if (begin != end && std::next(begin) != end)
259 return use.
getUser() != recipeOp.getOperation();
265template <
typename AccConstructT>
266static void hoistNonConstantDirectUses(AccConstructT accOp,
268 accOp.
walk([&](acc::AddressOfGlobalOpInterface addrOfOp) {
269 SymbolRefAttr symRef = addrOfOp.getSymbol();
273 if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
275 addrOfOp->moveBefore(accOp);
277 llvm::dbgs() <<
"Hoisted:\n\t" << addrOfOp <<
"\n\tfrom:\n\t";
278 accOp->print(llvm::dbgs(),
280 llvm::dbgs() <<
"\n");
287static void collectGlobalsFromDeviceRegion(
Region ®ion,
288 GlobalOpSetT &globals,
293 auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op);
295 SymbolRefAttr symRef = addrOfOp.getSymbol();
303 if (isCandidate && globalOp && isValidForAccDeclare(globalOp))
304 globals.insert(globalOp);
305 }
else if (
auto indirectAccessOp =
306 dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) {
309 indirectAccessOp.getReferencedSymbols(symbols, &symTab);
310 for (SymbolRefAttr symRef : symbols)
312 if (isValidForAccDeclare(globalOp))
313 globals.insert(globalOp);
320 acc::DataClause clause) {
322 acc::DeclareAttr::get(context,
323 acc::DataClauseAttr::get(context, clause)));
328class ACCImplicitDeclare
331 using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase;
333 void runOnOperation()
override {
334 ModuleOp mod = getOperation();
346 hoistNonConstantDirectUses(accOp, accSupport);
354 GlobalOpSetT globalsToAccDeclare;
359 collectGlobalsFromDeviceRegion(
360 accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
362 .Case<FunctionOpInterface>([&](
auto func) {
366 collectGlobalsFromDeviceRegion(
func.getFunctionBody(),
367 globalsToAccDeclare, accSupport,
370 .Case<acc::GlobalVariableOpInterface>([&](
auto globalVarOp) {
372 if (
Region *initRegion = globalVarOp.getInitRegion())
373 collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare,
376 .Case<acc::PrivateRecipeOp>([&](
auto privateRecipe) {
377 if (hasRelevantRecipeUse(privateRecipe, mod)) {
378 collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(),
379 globalsToAccDeclare, accSupport,
381 collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(),
382 globalsToAccDeclare, accSupport,
386 .Case<acc::FirstprivateRecipeOp>([&](
auto firstprivateRecipe) {
387 if (hasRelevantRecipeUse(firstprivateRecipe, mod)) {
388 collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(),
389 globalsToAccDeclare, accSupport,
391 collectGlobalsFromDeviceRegion(
392 firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare,
394 collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(),
395 globalsToAccDeclare, accSupport,
399 .Case<acc::ReductionRecipeOp>([&](
auto reductionRecipe) {
400 if (hasRelevantRecipeUse(reductionRecipe, mod)) {
401 collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(),
402 globalsToAccDeclare, accSupport,
404 collectGlobalsFromDeviceRegion(
405 reductionRecipe.getCombinerRegion(), globalsToAccDeclare,
413 for (
Operation *globalOp : globalsToAccDeclare) {
415 llvm::dbgs() <<
"Global is being `acc declare copyin`d: ";
416 globalOp->
print(llvm::dbgs(),
418 llvm::dbgs() <<
"\n");
421 addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin);
#define ACC_COMPUTE_CONSTRUCT_OPS
MLIRContext is the top-level object for a collection of MLIR operations.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Operation is the basic unit of execution within MLIR.
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value; otherwise, add a new attribute with that name and value.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
This class represents a specific symbol use.
Operation * getUser() const
Return the operation user of this symbol reference.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
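Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was found.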
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
bool isAccRoutine(mlir::Operation *op)
Used to check whether the current operation is marked with acc routine.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of an acc routine function.
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch