187#include "llvm/ADT/SmallVector.h"
188#include "llvm/ADT/TypeSwitch.h"
192#define GEN_PASS_DEF_ACCIMPLICITDECLARE
193#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
197#define DEBUG_TYPE "acc-implicit-declare"
// Insertion-ordered, duplicate-free set of global symbol operations
// (SmallSetVector keeps up to 16 elements inline before allocating).
// Iterating it visits globals in the order they were first inserted, which
// keeps the order in which declare attributes are applied deterministic.
203using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>;
// Decides whether a use of the global `globalOp` (through `symbol`) is a
// candidate for being hoisted out of the enclosing OpenACC construct.
// The final return yields true only when the global is neither a constant
// variable (a GlobalVariableOpInterface whose isConstant() is true) nor a
// function (tracked through `isFunction` for FunctionOpInterface ops).
// NOTE(review): some parameters and the body of the FunctionOpInterface check
// are not visible here; confirm no other early exits affect the result.
208static bool isGlobalUseCandidateForHoisting(
Operation *globalOp,
210 SymbolRefAttr symbol,
217 bool isConstant =
false;
218 bool isFunction =
false;
220 if (
auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp))
221 isConstant = globalVarOp.isConstant();
223 if (isa<FunctionOpInterface>(globalOp))
230 return !isConstant && !isFunction;
// Returns true if `globalOp` may receive an implicit acc declare attribute.
// Function symbols (ops implementing FunctionOpInterface) are rejected; any
// other global op is accepted.
234bool isValidForAccDeclare(
Operation *globalOp) {
236 return !isa<FunctionOpInterface>(globalOp);
// Reports whether the recipe `recipeOp` is used in `mod` in a way that makes
// its regions worth scanning for referenced globals.
// The recipe's symbol uses are queried within `mod`.
// - The "no uses" case takes an early branch.
// - The "more than one use" case takes a separate early branch.
// The return values of both branches are not shown here; presumably false
// and true respectively (TODO confirm).
// The remaining check compares a use's user against the recipe op itself,
// so a reference from the recipe to its own symbol is not treated as relevant.
243template <
typename RecipeOpT>
244static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) {
245 std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod);
248 if (!symbolUses.has_value() || symbolUses->empty())
252 auto begin = symbolUses->begin();
253 auto end = symbolUses->end();
254 if (begin != end && std::next(begin) != end)
259 return use.
getUser() != recipeOp.getOperation();
// Walks `accOp` and, for every AddressOfGlobalOpInterface op whose global
// passes isGlobalUseCandidateForHoisting, moves that address-of op to just
// before `accOp`.
// The enclosing acc::ComputeRegionOp, if any, is looked up *before* the move,
// because moveBefore changes the op's parent. Each result of the hoisted op
// is then handed to that region's wireHoistedValueThroughIns, presumably so
// that uses left inside the region receive the value through the region's
// inputs (TODO confirm).
// The trailing statements emit debug output to llvm::dbgs(): the hoisted op
// and the construct it was hoisted from.
// NOTE(review): the lookup of `globalOp` from `symRef` is not visible here.
265template <
typename AccConstructT>
266static void hoistNonConstantDirectUses(AccConstructT accOp,
268 accOp.
walk([&](acc::AddressOfGlobalOpInterface addrOfOp) {
269 SymbolRefAttr symRef = addrOfOp.getSymbol();
273 if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
275 auto computeRegionParent =
276 addrOfOp->getParentOfType<acc::ComputeRegionOp>();
277 addrOfOp->moveBefore(accOp);
278 if (computeRegionParent)
279 for (
Value v : addrOfOp->getResults())
280 computeRegionParent.wireHoistedValueThroughIns(v);
282 llvm::dbgs() <<
"Hoisted:\n\t" << addrOfOp <<
"\n\tfrom:\n\t";
283 accOp->print(llvm::dbgs(),
285 llvm::dbgs() <<
"\n");
// Collects into `globals` the global ops referenced from `region` that should
// receive an implicit acc declare attribute.
// - A direct reference (an AddressOfGlobalOpInterface op) is added when
//   `isCandidate` holds, the global op was resolved, and isValidForAccDeclare
//   accepts it.
// - An indirect reference (an IndirectGlobalAccessOpInterface op) reports
//   its symbols through getReferencedSymbols (resolved via `symTab`). Each
//   resolved global accepted by isValidForAccDeclare is added.
// Because `globals` is a set vector, each global is recorded only once.
// NOTE(review): the region walk, the symbol lookups and the computation of
// `isCandidate` are not visible here.
292static void collectGlobalsFromDeviceRegion(
Region ®ion,
293 GlobalOpSetT &globals,
298 auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op);
300 SymbolRefAttr symRef = addrOfOp.getSymbol();
308 if (isCandidate && globalOp && isValidForAccDeclare(globalOp))
309 globals.insert(globalOp);
310 }
else if (
auto indirectAccessOp =
311 dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) {
314 indirectAccessOp.getReferencedSymbols(symbols, &symTab);
315 for (SymbolRefAttr symRef : symbols)
317 if (isValidForAccDeclare(globalOp))
318 globals.insert(globalOp);
325 acc::DataClause clause) {
327 acc::DeclareAttr::get(context,
328 acc::DataClauseAttr::get(context, clause)));
// Module-level pass that marks globals referenced from device regions with an
// implicit `acc declare copyin` attribute.
// Visible phases of runOnOperation:
//  1. hoistNonConstantDirectUses is applied to OpenACC constructs, moving
//     hoistable address-of-global ops out of the construct.
//  2. Globals referenced from device regions are gathered into
//     `globalsToAccDeclare` by scanning:
//     - acc construct regions;
//     - function bodies;
//     - global variable init regions;
//     - the regions of private (init/destroy), firstprivate
//       (init/destroy/copy) and reduction (init/combiner) recipes for which
//       hasRelevantRecipeUse reports a relevant use.
//  3. Every collected global gets addDeclareAttr(..., acc_copyin), in
//     collection order.
// NOTE(review): the filters selecting which constructs and functions are
// scanned (e.g. whether only acc routines are considered) are not visible here.
333class ACCImplicitDeclare
334 :
public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> {
336 using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase;
338 void runOnOperation()
override {
339 ModuleOp mod = getOperation();
351 hoistNonConstantDirectUses(accOp, accSupport);
359 GlobalOpSetT globalsToAccDeclare;
364 collectGlobalsFromDeviceRegion(
365 accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
367 .Case([&](FunctionOpInterface
func) {
371 collectGlobalsFromDeviceRegion(
func.getFunctionBody(),
372 globalsToAccDeclare, accSupport,
375 .Case([&](acc::GlobalVariableOpInterface globalVarOp) {
377 if (
Region *initRegion = globalVarOp.getInitRegion())
378 collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare,
381 .Case([&](acc::PrivateRecipeOp privateRecipe) {
382 if (hasRelevantRecipeUse(privateRecipe, mod)) {
383 collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(),
384 globalsToAccDeclare, accSupport,
386 collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(),
387 globalsToAccDeclare, accSupport,
391 .Case([&](acc::FirstprivateRecipeOp firstprivateRecipe) {
392 if (hasRelevantRecipeUse(firstprivateRecipe, mod)) {
393 collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(),
394 globalsToAccDeclare, accSupport,
396 collectGlobalsFromDeviceRegion(
397 firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare,
399 collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(),
400 globalsToAccDeclare, accSupport,
404 .Case([&](acc::ReductionRecipeOp reductionRecipe) {
405 if (hasRelevantRecipeUse(reductionRecipe, mod)) {
406 collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(),
407 globalsToAccDeclare, accSupport,
409 collectGlobalsFromDeviceRegion(
410 reductionRecipe.getCombinerRegion(), globalsToAccDeclare,
418 for (
Operation *globalOp : globalsToAccDeclare) {
420 llvm::dbgs() <<
"Global is being `acc declare copyin`d: ";
421 globalOp->
print(llvm::dbgs(),
423 llvm::dbgs() <<
"\n");
426 addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin);
MLIRContext is the top-level object for a collection of MLIR operations.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Operation is the basic unit of execution within MLIR.
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class represents a specific symbol use.
Operation * getUser() const
Return the operation user of this symbol reference.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
#define ACC_COMPUTE_CONSTRUCT_OPS
bool isAccRoutine(mlir::Operation *op)
Used to check whether the current operation is marked with acc routine.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch