MLIR 23.0.0git
ACCComputeLowering.cpp
Go to the documentation of this file.
1//===- ACCComputeLowering.cpp - Lower ACC compute to compute_region -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass decomposes OpenACC compute constructs into a representation that
10// separates the data environment from the compute portion and prepares for
11// parallelism assignment and privatization at the appropriate level.
12//
13// Overview:
14// ---------
15// Each compute construct (`acc.parallel`, `acc.serial`, `acc.kernels`) is
16// lowered to (1) `acc.kernel_environment`, which captures the data environment
17// and (2) `acc.compute_region`, which holds the compute body. Inside the
18// compute region, acc.loop is converted to SCF loops (`scf.parallel` or
19// `scf.for`) with any predetermined parallelism expressed as `par_dims`. This
20// decomposition allows later phases to assign parallelism and handle
21// privatization at the right granularity.
22//
23// Transformations:
24// ----------------
25// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
26// replaced by acc.kernel_environment containing a single acc.compute_region.
27// Launch arguments (num_gangs, num_workers, vector_length) become
28// acc.par_width ops and are passed as compute_region launch operands.
29//
30// 2. acc.loop: Converted according to context and attributes:
31// - Unstructured: body wrapped in scf.execute_region.
32// - Sequential (serial region or seq clause): scf.parallel with
33// par_dims = sequential.
34// - Auto (in parallel/kernels): scf.for with collapse when
35// multi-dimensional.
36// - Orphan (not inside a compute construct): scf.for, no collapse.
37// - Independent (in parallel/kernels): scf.parallel with par_dims from
38// gang/worker/vector mapping (e.g. block_x).
39//
40//===----------------------------------------------------------------------===//
41
43
52#include "mlir/IR/IRMapping.h"
56
57namespace mlir {
58namespace acc {
59#define GEN_PASS_DEF_ACCCOMPUTELOWERING
60#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
61} // namespace acc
62} // namespace mlir
63
64#define DEBUG_TYPE "acc-compute-lowering"
65
66using namespace mlir;
67using namespace mlir::acc;
68
69namespace {
70
71//===----------------------------------------------------------------------===//
72// Helper functions
73//===----------------------------------------------------------------------===//
74
75/// Strip index_cast operations from a value before checking for a constant.
76static Value stripIndexCasts(Value val) {
77 while (auto castOp = val.getDefiningOp<arith::IndexCastOp>())
78 val = castOp.getIn();
79 return val;
80}
81
82/// A parallel construct is "effectively serial" when it specifies
83/// num_gangs(1), num_workers(1), and vector_length(1). This matches
84/// the semantics of acc.serial but expressed through acc.parallel.
85static bool isEffectivelySerial(ParallelOp op) {
86 auto numGangs = op.getNumGangsValues();
87 if (numGangs.size() != 1)
88 return false;
89 Value numWorkers = op.getNumWorkersValue();
90 if (!numWorkers)
91 return false;
92 Value vectorLength = op.getVectorLengthValue();
93 if (!vectorLength)
94 return false;
95 return isConstantIntValue(stripIndexCasts(numGangs.front()), 1) &&
96 isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
97 isConstantIntValue(stripIndexCasts(vectorLength), 1);
98}
99
100static bool isOpInComputeRegion(Operation *op) {
101 Region *region = op->getBlock()->getParent();
102 return getEnclosingComputeOp(*region) != nullptr;
103}
104
105static bool isOpInSerialRegion(Operation *op) {
106 if (auto parallelOp = op->getParentOfType<ParallelOp>())
107 return isEffectivelySerial(parallelOp);
108 if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
109 return computeRegion.isEffectivelySerial();
110 if (op->getParentOfType<SerialOp>())
111 return true;
112 if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
113 if (isSpecializedAccRoutine(funcOp)) {
114 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
116 if (attr && attr.getLevel().getValue() == ParLevel::seq)
117 return true;
118 }
119 }
120 return false;
121}
122
123static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
124 op->setAttr(GPUParallelDimsAttr::name, attr);
125}
126
127/// Insert a parallel dimension into the list, maintaining order by
128/// GPUParallelDimAttr::getOrder (descending).
129static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
130 GPUParallelDimAttr parDim) {
131 GPUParallelDimAttr *lb = llvm::lower_bound(
132 parDims, parDim,
133 [](const GPUParallelDimAttr &a, const GPUParallelDimAttr &b) {
134 return a.getOrder() > b.getOrder();
135 });
136 if (lb == parDims.end() || *lb != parDim)
137 parDims.insert(lb, parDim);
138}
139
140/// Map loop parallelism clauses (gang/worker/vector) to GPU parallel
141/// dimensions using the given mapping policy.
143getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
144 DeviceType deviceType) {
146 auto *ctx = loopOp->getContext();
147
148 if (loopOp.hasVector(deviceType))
149 insertParDim(parDims, policy.vectorDim(ctx));
150 if (loopOp.hasWorker(deviceType))
151 insertParDim(parDims, policy.workerDim(ctx));
152 if (auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
153 if (auto gangDimDefOp =
154 gangDimValue.getDefiningOp<arith::ConstantIntOp>()) {
155 auto gangLevel = getGangParLevel(gangDimDefOp.value());
156 insertParDim(parDims, policy.gangDim(ctx, gangLevel));
157 }
158 } else if (loopOp.hasGang(deviceType)) {
159 insertParDim(parDims, policy.gangDim(ctx, ParLevel::gang_dim1));
160 }
161 return parDims;
162}
163
164/// Create acc.par_width operations from gang/worker/vector values of a
165/// compute construct. Queries the device-type-specific values first, falling
166/// back to the default (DeviceType::None) values.
167template <typename ComputeConstructT>
169assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
170 RewriterBase &rewriter,
171 const ACCToGPUMappingPolicy &policy) {
172 SmallVector<Value> values;
173 auto *ctx = rewriter.getContext();
174 auto indexTy = rewriter.getIndexType();
175 auto loc = computeOp->getLoc();
176
177 auto numGangs = computeOp.getNumGangsValues(deviceType);
178 if (numGangs.empty())
179 numGangs = computeOp.getNumGangsValues();
180 for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
181 auto gangLevel = getGangParLevel(gangDimIdx + 1);
182 values.push_back(
183 ParWidthOp::create(rewriter, loc,
185 rewriter, gangSize.getLoc(), indexTy, gangSize),
186 policy.gangDim(ctx, gangLevel)));
187 }
188
189 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
190 if (!numWorkers)
191 numWorkers = computeOp.getNumWorkersValue();
192 if (numWorkers) {
193 values.push_back(ParWidthOp::create(
194 rewriter, loc,
195 getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
196 numWorkers),
197 policy.workerDim(ctx)));
198 }
199
200 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
201 if (!vectorLength)
202 vectorLength = computeOp.getVectorLengthValue();
203 if (vectorLength) {
204 values.push_back(ParWidthOp::create(
205 rewriter, loc,
206 getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
207 indexTy, vectorLength),
208 policy.vectorDim(ctx)));
209 }
210 return values;
211}
212
213/// SerialOp has no gang/worker/vector clauses.
214template <>
216assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType, RewriterBase &,
217 const ACCToGPUMappingPolicy &) {
218 return {};
219}
220
221//===----------------------------------------------------------------------===//
222// Loop conversion pattern
223//===----------------------------------------------------------------------===//
224
225class ACCLoopConversion : public OpRewritePattern<LoopOp> {
226public:
227 ACCLoopConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
228 DeviceType deviceType)
229 : OpRewritePattern<LoopOp>(ctx), policy(policy), deviceType(deviceType) {}
230
231 LogicalResult matchAndRewrite(LoopOp loopOp,
232 PatternRewriter &rewriter) const override {
233 if (loopOp.getUnstructured()) {
234 auto executeRegion =
236 if (!executeRegion)
237 return failure();
238 rewriter.replaceOp(loopOp, executeRegion);
239 return success();
240 }
241
242 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
243
244 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
245 // Although it might seem unintuitive, scf.parallel is used here because
246 // the parallelism of the loop is already predetermined (as sequential).
247 // scf.for will become a candidate for auto-parallelization analysis.
248 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
249 if (!parallelOp)
250 return failure();
251 setParDimsAttr(parallelOp,
252 GPUParallelDimsAttr::seq(loopOp->getContext()));
253 rewriter.replaceOp(loopOp, parallelOp);
254 } else if (parMode == LoopParMode::loop_auto) {
255 // All loops in serial regions should have already been handled.
256 assert(!isOpInSerialRegion(loopOp) &&
257 "Expected loop to be in non-serial region");
258 // Mark as scf.for to allow auto-parallelization analysis later.
259 auto forOp =
260 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/true);
261 if (!forOp)
262 return failure();
263 rewriter.replaceOp(loopOp, forOp);
264 } else if (!isOpInComputeRegion(loopOp) &&
266 loopOp->getParentOfType<FunctionOpInterface>())) {
267 // This loop is an orphan `acc loop` but it is not in any sort
268 // of compute region. Thus it is just a sequential non-accelerator loop.
269 auto forOp =
270 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/false);
271 if (!forOp)
272 return failure();
273 rewriter.replaceOp(loopOp, forOp);
274 } else {
275 assert(parMode == LoopParMode::loop_independent &&
276 "Expected loop to be independent");
277 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
278 if (!parallelOp)
279 return failure();
280
282 getParallelDimensions(loopOp, policy, deviceType);
283 if (!parDims.empty()) {
284 auto parDimsAttr =
285 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
286 setParDimsAttr(parallelOp, parDimsAttr);
287 }
288
289 rewriter.replaceOp(loopOp, parallelOp);
290 }
291 return success();
292 }
293
294private:
295 const ACCToGPUMappingPolicy &policy;
296 DeviceType deviceType;
297};
298
299//===----------------------------------------------------------------------===//
300// Compute construct conversion pattern
301//===----------------------------------------------------------------------===//
302
303template <typename ComputeConstructT>
304class ComputeOpConversion : public OpRewritePattern<ComputeConstructT> {
305public:
306 ComputeOpConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
307 DeviceType deviceType)
308 : OpRewritePattern<ComputeConstructT>(ctx), policy(policy),
309 deviceType(deviceType) {}
310
311 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
312 PatternRewriter &rewriter) const override {
313 rewriter.setInsertionPoint(computeOp);
314 auto kernelEnv =
315 KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
316 auto launchArgs =
317 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
318 Region &region = computeOp.getRegion();
319 SetVector<Value> liveInValues;
320 getUsedValuesDefinedAbove(region, region, liveInValues);
321 IRMapping mapping;
322 auto computeRegion = buildComputeRegion(
323 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
324 ComputeConstructT::getOperationName(), region, rewriter, mapping);
325 if (!computeRegion) {
326 rewriter.eraseOp(kernelEnv);
327 return failure();
328 }
329 rewriter.eraseOp(computeOp);
330 return success();
331 }
332
333private:
334 const ACCToGPUMappingPolicy &policy;
335 DeviceType deviceType;
336};
337
338//===----------------------------------------------------------------------===//
339// Pass implementation
340//===----------------------------------------------------------------------===//
341
/// Pass that lowers OpenACC compute constructs in two greedy-rewrite phases
/// over the pass's root op:
///   1. every acc.loop is converted to SCF (scf.parallel / scf.for /
///      scf.execute_region) by ACCLoopConversion;
///   2. every acc.parallel, acc.kernels, and acc.serial is replaced by
///      acc.kernel_environment { acc.compute_region { ... } }.
/// Phase 1 must run first because the loop conversion looks at the
/// enclosing compute construct to pick its strategy.
342class ACCComputeLowering
343 : public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
344public:
345 using ACCComputeLoweringBase::ACCComputeLoweringBase;
346
347 void runOnOperation() override {
348 auto op = getOperation();
349 auto *context = op.getContext();
350
  // NOTE(review): `policy` (the ACCToGPUMappingPolicy passed to both pattern
  // sets below) is declared on a source line that is missing from this
  // listing. Both pattern classes store it by const reference, so it must
  // outlive both applyPatternsGreedily calls. `deviceType` is presumably a
  // pass option declared in ACCComputeLoweringBase — confirm.
352
353 // Part 1: Convert acc.loop to scf.parallel/scf.for while the parent
354 // compute construct is still present (needed to determine conversion
355 // strategy).
356 RewritePatternSet loopPatterns(context);
357 loopPatterns.insert<ACCLoopConversion>(context, policy, deviceType);
  // A greedy-driver failure in either phase aborts the pass.
358 if (failed(applyPatternsGreedily(op, std::move(loopPatterns))))
359 return signalPassFailure();
360
361 // Part 2: Convert acc.parallel, acc.kernels, and acc.serial to
362 // acc.kernel_environment { acc.compute_region { ... } }.
363 RewritePatternSet computePatterns(context);
364 computePatterns
365 .insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
366 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
367 if (failed(applyPatternsGreedily(op, std::move(computePatterns))))
368 return signalPassFailure();
369 }
370};
371
372} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
Definition OpenACC.h:185
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
Definition OpenACC.h:197
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={})
Build an acc.compute_region operation by cloning a source region.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:120
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...