MLIR 23.0.0git
ACCRoutineLowering.cpp
Go to the documentation of this file.
1//===- ACCRoutineLowering.cpp - Wrap ACC routines in compute_region -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass handles the `acc routine` directive by creating specialized
// functions that carry the appropriate parallelism information; these are
// later used to create the actual device functions.
12//
13// Overview:
14// ---------
// For each acc.routine that is not bound by name and whose function has a
// body, the pass creates a new function (the "device" copy) whose body is a
// single acc.compute_region containing a clone of the original (host)
// function body. Parallelism is expressed by one acc.par_width derived from
// the routine's clauses (seq, vector, worker, gang). The device copy is only a
// staging location; it is later moved into a function at device-module level.
21//
22// Transformations:
23// ----------------
// 1. Device function: Same signature as the host; all attributes are copied
//    except acc.routine_info. The acc.specialized_routine attribute is set
//    with the routine symbol, parallelism level, and original function name.
27//
28// 2. Body: One acc.par_width, one acc.compute_region that clones the host
29// body. Multi-block host bodies are wrapped in scf.execute_region inside
30// the compute_region.
31//
32// 3. Finalization: acc.routine's func_name is updated to the device function.
33// For nohost routines, all uses of the host symbol are replaced with the
34// device symbol and the host function is erased.
35//
36//===----------------------------------------------------------------------===//
37
39
45#include "mlir/IR/IRMapping.h"
47#include "mlir/IR/SymbolTable.h"
48#include "mlir/IR/Value.h"
49
50namespace mlir {
51namespace acc {
52#define GEN_PASS_DEF_ACCROUTINELOWERING
53#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
54} // namespace acc
55} // namespace mlir
56
57#define DEBUG_TYPE "acc-routine-lowering"
58
59using namespace mlir;
60using namespace mlir::acc;
61
62namespace {
63
64/// Compute the ParLevel from an acc.routine op for specialization.
65static ParLevel computeParLevel(RoutineOp routineOp, DeviceType deviceType) {
66 auto gangDim = routineOp.getGangDimValue(deviceType);
67 if (!gangDim)
68 gangDim = routineOp.getGangDimValue();
69 if (gangDim) {
70 switch (*gangDim) {
71 case 1:
72 return ParLevel::gang_dim1;
73 case 2:
74 return ParLevel::gang_dim2;
75 case 3:
76 return ParLevel::gang_dim3;
77 default:
78 break;
79 }
80 }
81 if (routineOp.hasGang(deviceType) || routineOp.hasGang())
82 return ParLevel::gang_dim1;
83 if (routineOp.hasWorker(deviceType) || routineOp.hasWorker())
84 return ParLevel::worker;
85 if (routineOp.hasVector(deviceType) || routineOp.hasVector())
86 return ParLevel::vector;
87 return ParLevel::seq;
88}
89
90/// Collect return operands from the function (first block with func.return).
91static void getReturnValues(func::FuncOp func, SmallVectorImpl<Value> &result) {
92 result.clear();
93 for (Block &block : func.getBody().getBlocks()) {
94 if (auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator())) {
95 result.assign(returnOp.operand_begin(), returnOp.operand_end());
96 break;
97 }
98 }
99}
100
101/// Create the device function with the same signature as the host, set
102/// specialized_routine, and add a single block with the same block arguments.
103static func::FuncOp createFunctionForDeviceStaging(func::FuncOp hostFunc,
104 RoutineOp routineOp,
105 ParLevel parLevel,
106 MLIRContext *ctx,
107 IRRewriter &rewriter) {
108 Location loc = hostFunc.getLoc();
109 FunctionType funcType = hostFunc.getFunctionType();
110 func::FuncOp deviceFunc =
111 func::FuncOp::create(rewriter, loc, hostFunc.getName(), funcType);
112 deviceFunc->setAttrs(hostFunc->getAttrs());
113 deviceFunc->removeAttr(getRoutineInfoAttrName());
114 deviceFunc->setAttr(getSpecializedRoutineAttrName(),
115 SpecializedRoutineAttr::get(
116 ctx, SymbolRefAttr::get(ctx, routineOp.getSymName()),
117 ParLevelAttr::get(ctx, parLevel),
118 StringAttr::get(ctx, hostFunc.getName())));
119
120 Block *sourceBlock = &hostFunc.getBody().front();
121 Block *newBlock = rewriter.createBlock(&deviceFunc.getRegion());
122 for (BlockArgument arg : sourceBlock->getArguments())
123 newBlock->addArgument(arg.getType(), hostFunc.getLoc());
124
125 return deviceFunc;
126}
127
128/// Fill the device function body: one acc.par_width, one acc.compute_region
129/// (cloning the host body with inputArgsToMap), then func.return.
130static LogicalResult
131buildRoutineBody(func::FuncOp deviceFunc, func::FuncOp hostFunc,
132 ArrayRef<Value> funcReturnVals, ParLevel parLevel,
133 DefaultACCToGPUMappingPolicy &policy, IRRewriter &rewriter) {
134 Block *newBlock = &deviceFunc.getBody().front();
135 Block *sourceBlock = &hostFunc.getBody().front();
136 Location loc = hostFunc.getLoc();
137 MLIRContext *ctx = rewriter.getContext();
138
139 rewriter.setInsertionPointToStart(newBlock);
140 GPUParallelDimAttr parDim = policy.map(ctx, parLevel);
141 Value parWidthVal = ParWidthOp::create(rewriter, loc, Value(), parDim);
142 SmallVector<Value, 4> inputArgs(newBlock->getArguments().begin(),
143 newBlock->getArguments().end());
144
145 // Normally the region passed to buildComputeRegion is something in the
146 // current function. Here we pass the body of the original (host) function as
147 // an optimization to avoid cloning twice (once for a staged device copy and
148 // again when creating the compute region). Since we clone only once, we must
149 // also provide the original function's arguments so the mapping is correct
150 // when cloning the body.
151 ValueRange sourceArgsToMap = sourceBlock->getArguments();
152
153 IRMapping mapping;
154 rewriter.setInsertionPointAfter(parWidthVal.getDefiningOp());
155 ComputeRegionOp computeRegion = buildComputeRegion(
156 loc, {parWidthVal}, inputArgs, RoutineOp::getOperationName(),
157 hostFunc.getBody(), rewriter, mapping,
158 /*output=*/funcReturnVals, /*kernelFuncName=*/{},
159 /*kernelModuleName=*/{}, /*stream=*/{}, sourceArgsToMap);
160 if (!computeRegion)
161 return failure();
162
163 rewriter.setInsertionPointAfter(computeRegion);
164 if (funcReturnVals.empty())
165 func::ReturnOp::create(rewriter, loc);
166 else
167 func::ReturnOp::create(rewriter, loc, computeRegion.getResults());
168
169 return success();
170}
171
172/// Update acc.routine refs and optionally erase host for nohost routines.
173static LogicalResult finalizeRoutines(
174 SmallVectorImpl<std::tuple<func::FuncOp, func::FuncOp, RoutineOp>>
175 &accRoutineInfo,
176 ModuleOp mod, MLIRContext *ctx) {
177 for (auto &[hostFunc, deviceFunc, routineOp] : accRoutineInfo) {
178 routineOp.setFuncNameAttr(SymbolRefAttr::get(ctx, deviceFunc.getName()));
179 routineOp->moveBefore(deviceFunc);
180
181 if (routineOp.getNohost()) {
183 StringAttr::get(ctx, hostFunc.getName()),
184 StringAttr::get(ctx, deviceFunc.getName()), mod))) {
185 routineOp.emitError("cannot replace symbol uses for acc routine");
186 return failure();
187 }
188 hostFunc->erase();
189 }
190 }
191 return success();
192}
193
194class ACCRoutineLowering
195 : public acc::impl::ACCRoutineLoweringBase<ACCRoutineLowering> {
196public:
197 using ACCRoutineLoweringBase::ACCRoutineLoweringBase;
198
199 void runOnOperation() override {
200 ModuleOp mod = getOperation();
201 if (mod.getOps<RoutineOp>().empty()) {
202 LLVM_DEBUG(llvm::dbgs()
203 << "Skipping ACCRoutineLowering - no acc.routine ops\n");
204 return;
205 }
206
207 SymbolTable symTab(mod);
208 MLIRContext *ctx = mod.getContext();
209 IRRewriter rewriter(ctx);
211
212 // Tuple: host function, device function, routine operation
214 accRoutineInfo;
215
216 for (RoutineOp routineOp : mod.getOps<RoutineOp>()) {
217 if (routineOp.getBindNameValue() ||
218 routineOp.getBindNameValue(deviceType))
219 continue;
220
221 func::FuncOp hostFunc = symTab.lookup<func::FuncOp>(
222 routineOp.getFuncName().getLeafReference());
223 if (!hostFunc) {
224 routineOp.emitError("acc routine function not found in symbol table");
225 return signalPassFailure();
226 }
227 if (hostFunc.isExternal())
228 continue;
229
230 SmallVector<Value, 4> funcReturnVals;
231 getReturnValues(hostFunc, funcReturnVals);
232
233 OpBuilder::InsertionGuard guard(rewriter);
234 ParLevel parLevel = computeParLevel(routineOp, deviceType);
235 func::FuncOp deviceFunc = createFunctionForDeviceStaging(
236 hostFunc, routineOp, parLevel, ctx, rewriter);
237 if (failed(buildRoutineBody(deviceFunc, hostFunc, funcReturnVals,
238 parLevel, policy, rewriter)))
239 return signalPassFailure();
240
241 accRoutineInfo.push_back({hostFunc, deviceFunc, routineOp});
242 symTab.insert(deviceFunc);
243 }
244
245 if (failed(finalizeRoutines(accRoutineInfo, mod, ctx)))
246 return signalPassFailure();
247 }
248};
249
250} // namespace
return success()
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
mlir::acc::GPUParallelDimAttr map(MLIRContext *ctx, ParLevel level) const override
Map an OpenACC parallelism level to target dimension.
static constexpr StringLiteral getSpecializedRoutineAttrName()
Definition OpenACC.h:185
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
static constexpr StringLiteral getRoutineInfoAttrName()
Definition OpenACC.h:181
Include the generated interface declarations.