52#define GEN_PASS_DEF_ACCROUTINELOWERING
53#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
57#define DEBUG_TYPE "acc-routine-lowering"
/// Picks the ParLevel to use for `routineOp` when lowering for `deviceType`.
/// NOTE(review): this listing omits several original lines; the guards around
/// the gang-dim returns and whatever is returned when no clause matches are
/// not visible here.
65static ParLevel computeParLevel(RoutineOp routineOp, DeviceType deviceType) {
// Start with the gang dimension recorded for this specific device type.
66 auto gangDim = routineOp.getGangDimValue(deviceType);
// Re-query through the overload that takes no device type. Presumably this
// only runs when the device-specific lookup found nothing -- the guarding
// condition is elided, confirm against the full file.
68 gangDim = routineOp.getGangDimValue();
// One return per gang dimension (1, 2, 3); the dispatch that selects among
// them is elided from this listing.
72 return ParLevel::gang_dim1;
74 return ParLevel::gang_dim2;
76 return ParLevel::gang_dim3;
// Clause checks, tested in the order gang, worker, vector. Each one accepts
// either the device-specific query or the overload without a device type.
// A gang clause reached here maps to gang_dim1.
81 if (routineOp.hasGang(deviceType) || routineOp.hasGang())
82 return ParLevel::gang_dim1;
83 if (routineOp.hasWorker(deviceType) || routineOp.hasWorker())
84 return ParLevel::worker;
85 if (routineOp.hasVector(deviceType) || routineOp.hasVector())
86 return ParLevel::vector;
// NOTE(review): the enclosing function's signature is not part of this
// listing; it looks like the body of the helper invoked elsewhere in this
// file as getReturnValues(hostFunc, funcReturnVals) -- confirm.
// Walk every block of the function body.
93 for (
Block &block :
func.getBody().getBlocks()) {
// Only blocks whose terminator is a func.return are of interest.
94 if (
auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator())) {
// Replace the contents of `result` with the operands of that return.
// Whether the loop stops after the first match is not visible here.
95 result.assign(returnOp.operand_begin(), returnOp.operand_end());
/// Creates the func.func that stands in for `hostFunc` on the device side.
/// NOTE(review): most of the parameter list, the call that consumes the
/// SpecializedRoutineAttr, the creation of `newBlock`, and the loop header
/// over the source arguments are elided from this listing.
103static func::FuncOp createFunctionForDeviceStaging(func::FuncOp hostFunc,
// The new function reuses the host function's type and is created under the
// host function's name.
109 FunctionType funcType = hostFunc.getFunctionType();
110 func::FuncOp deviceFunc =
111 func::FuncOp::create(rewriter, loc, hostFunc.getName(), funcType);
// Copy every attribute of the host function onto the new function.
112 deviceFunc->setAttrs(hostFunc->getAttrs());
// Build a SpecializedRoutineAttr from the routine's symbol name, the chosen
// parallelism level, and the host function's name. Presumably it is attached
// to deviceFunc -- the enclosing call is elided, confirm.
115 SpecializedRoutineAttr::get(
116 ctx, SymbolRefAttr::get(ctx, routineOp.getSymName()),
117 ParLevelAttr::get(ctx, parLevel),
118 StringAttr::get(ctx, hostFunc.getName())));
// First block of the host function body.
120 Block *sourceBlock = &hostFunc.getBody().
front();
// Append an argument of the same type to the new block, located at the host
// function's location. Presumably executed once per sourceBlock argument.
123 newBlock->
addArgument(arg.getType(), hostFunc.getLoc());
// Fills in the body of `deviceFunc` from the body of `hostFunc`.
// NOTE(review): the return type, the rest of the parameter list, the setup of
// inputArgs / mapping / sourceArgsToMap, and the name of the callee that
// produces `computeRegion` are elided from this listing. The caller wraps the
// result in failed(...).
131buildRoutineBody(func::FuncOp deviceFunc, func::FuncOp hostFunc,
// First blocks of the device and host function bodies.
134 Block *newBlock = &deviceFunc.getBody().
front();
135 Block *sourceBlock = &hostFunc.getBody().
front();
// Translate the OpenACC parallelism level into a GPU parallel dimension
// through the supplied policy.
140 GPUParallelDimAttr parDim = policy.
map(ctx, parLevel);
// Create a ParWidthOp for that dimension, passing a null Value as the operand.
141 Value parWidthVal = ParWidthOp::create(rewriter, loc,
Value(), parDim);
// Tail of the call that builds the compute region: parWidthVal is the only
// launch argument, the origin string is the acc.routine operation name, and
// the host function body is the region handed over. The argument order looks
// like buildComputeRegion's -- confirm against the elided call head.
156 loc, {parWidthVal}, inputArgs, RoutineOp::getOperationName(),
157 hostFunc.getBody(), rewriter, mapping,
159 {}, {}, sourceArgsToMap);
// Emit the func.return: without operands when the host function returned
// nothing ...
164 if (funcReturnVals.empty())
165 func::ReturnOp::create(rewriter, loc);
// ... otherwise (the `else` line is elided) forwarding the compute region's
// results.
167 func::ReturnOp::create(rewriter, loc, computeRegion.getResults());
/// Post-processing over every (host function, device function, routine)
/// triple collected by the pass.
/// NOTE(review): the parameter list, the head of the symbol-replacement call,
/// and the return statements are elided from this listing.
173static LogicalResult finalizeRoutines(
177 for (
auto &[hostFunc, deviceFunc, routineOp] : accRoutineInfo) {
// Point the routine at the device function by name and move the routine op
// so it sits directly before that function.
178 routineOp.setFuncNameAttr(SymbolRefAttr::get(ctx, deviceFunc.getName()));
179 routineOp->moveBefore(deviceFunc);
// Extra step taken only for routines marked nohost.
181 if (routineOp.getNohost()) {
// Arguments are (host name, device name, module); they match
// SymbolTable::replaceAllSymbolUses(oldSymbol, newSymbol, from), but the
// callee name itself is on an elided line -- confirm.
183 StringAttr::get(ctx, hostFunc.getName()),
184 StringAttr::get(ctx, deviceFunc.getName()), mod))) {
// Report the failed replacement on the routine op.
185 routineOp.emitError(
"cannot replace symbol uses for acc routine");
/// Module-level pass driven by the acc.routine ops found in the module.
/// NOTE(review): this listing omits many original lines -- the setup of ctx,
/// symTab, rewriter, policy, deviceType and accRoutineInfo, several guard
/// bodies, and all closing braces are not visible here.
194class ACCRoutineLowering
195 :
public acc::impl::ACCRoutineLoweringBase<ACCRoutineLowering> {
// Reuse the constructors of the generated base class.
197 using ACCRoutineLoweringBase::ACCRoutineLoweringBase;
// Pass entry point.
199 void runOnOperation()
override {
200 ModuleOp mod = getOperation();
// Nothing to do for a module without acc.routine ops; only a debug message
// is printed. The early exit itself is on an elided line.
201 if (mod.getOps<RoutineOp>().empty()) {
202 LLVM_DEBUG(llvm::dbgs()
203 <<
"Skipping ACCRoutineLowering - no acc.routine ops\n");
// Process each acc.routine op in the module.
216 for (RoutineOp routineOp : mod.getOps<RoutineOp>()) {
// Routines carrying a bind name (default or device-specific) get special
// handling; the action is elided -- presumably they are skipped, confirm.
217 if (routineOp.getBindNameValue() ||
218 routineOp.getBindNameValue(deviceType))
// Resolve the function the routine refers to by the leaf of its symbol ref.
221 func::FuncOp hostFunc = symTab.
lookup<func::FuncOp>(
222 routineOp.getFuncName().getLeafReference());
// Failed lookup: report on the routine op and fail the pass. The null check
// guarding these lines is elided.
224 routineOp.emitError(
"acc routine function not found in symbol table");
225 return signalPassFailure();
// External functions (declarations without a body) are singled out; the
// action is elided -- presumably skipped, confirm.
227 if (hostFunc.isExternal())
// Collect the values returned by the host function.
231 getReturnValues(hostFunc, funcReturnVals);
// Choose the parallelism level, create the device-side function, then
// populate its body. A failure while building the body fails the pass.
234 ParLevel parLevel = computeParLevel(routineOp, deviceType);
235 func::FuncOp deviceFunc = createFunctionForDeviceStaging(
236 hostFunc, routineOp, parLevel, ctx, rewriter);
237 if (failed(buildRoutineBody(deviceFunc, hostFunc, funcReturnVals,
238 parLevel, policy, rewriter)))
239 return signalPassFailure();
// Remember the triple so it can be finalized once all routines are processed.
241 accRoutineInfo.push_back({hostFunc, deviceFunc, routineOp});
// Register the new function in the symbol table. SymbolTable::insert renames
// the symbol as necessary to avoid collisions, which matters here because
// deviceFunc was created under the host function's name.
242 symTab.
insert(deviceFunc);
// Finalize every collected routine; any failure fails the pass.
245 if (failed(finalizeRoutines(accRoutineInfo, mod, ctx)))
246 return signalPassFailure();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
mlir::acc::GPUParallelDimAttr map(MLIRContext *ctx, ParLevel level) const override
Map an OpenACC parallelism level to target dimension.
static constexpr StringLiteral getSpecializedRoutineAttrName()
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
static constexpr StringLiteral getRoutineInfoAttrName()
Include the generated interface declarations.