MLIR 23.0.0git
ACCComputeLowering.cpp
Go to the documentation of this file.
1//===- ACCComputeLowering.cpp - Lower ACC compute to compute_region -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass decomposes OpenACC compute constructs into a representation that
10// separates the data environment from the compute portion and prepares for
11// parallelism assignment and privatization at the appropriate level.
12//
13// Overview:
14// ---------
15// Each compute construct (`acc.parallel`, `acc.serial`, `acc.kernels`) is
16// lowered to (1) `acc.kernel_environment`, which captures the data environment
17// and (2) `acc.compute_region`, which holds the compute body. Inside the
18// compute region, acc.loop is converted to SCF loops (`scf.parallel` or
19// `scf.for`) with any predetermined parallelism expressed as `par_dims`. This
20// decomposition allows later phases to assign parallelism and handle
21// privatization at the right granularity.
22//
23// Transformations:
24// ----------------
25// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
26// replaced by acc.kernel_environment containing a single acc.compute_region.
27// For acc.parallel / acc.kernels, launch arguments (num_gangs, num_workers,
28// vector_length) become acc.par_width ops (each result is `index`) and are
29// passed as compute_region launch operands. acc.serial, as well as compute
30// constructs with num_gangs(1), num_workers(1), and vector_length(1), instead
31// use a single sequential acc.par_width launch operand.
32//
33// 2. acc.loop: Converted according to context and attributes:
34// - Unstructured: body wrapped in scf.execute_region.
35// - Sequential (serial region, seq clause, or compute region with
36// num_gangs(1), num_workers(1), and vector_length(1)):
37// scf.parallel with par_dims = sequential.
38// - Auto (in parallel/kernels): scf.for with collapse when
39// multi-dimensional.
40// - Orphan (not inside a compute construct): scf.for, no collapse.
41// - Independent (in parallel/kernels): scf.parallel with par_dims from
42// gang/worker/vector mapping (e.g. block_x).
43//
44//===----------------------------------------------------------------------===//
45
47
56#include "mlir/IR/IRMapping.h"
57#include "mlir/IR/Matchers.h"
61#include "llvm/ADT/STLExtras.h"
62
63namespace mlir {
64namespace acc {
65#define GEN_PASS_DEF_ACCCOMPUTELOWERING
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
67} // namespace acc
68} // namespace mlir
69
70#define DEBUG_TYPE "acc-compute-lowering"
71
72using namespace mlir;
73using namespace mlir::acc;
74
75namespace {
76
77//===----------------------------------------------------------------------===//
78// Helper functions
79//===----------------------------------------------------------------------===//
80
81/// Strip index_cast operations from a value before checking for a constant.
82static Value stripIndexCasts(Value val) {
83 while (auto castOp = val.getDefiningOp<arith::IndexCastOp>())
84 val = castOp.getIn();
85 return val;
86}
87
88template <typename ComputeOpT>
89static bool isGangWorkerVectorAllOne(ComputeOpT op) {
90 auto numGangs = op.getNumGangsValues();
91 if (numGangs.empty())
92 return false;
93 for (Value gangSize : numGangs) {
94 if (!isConstantIntValue(stripIndexCasts(gangSize), 1))
95 return false;
96 }
97 Value numWorkers = op.getNumWorkersValue();
98 if (!numWorkers)
99 return false;
100 Value vectorLength = op.getVectorLengthValue();
101 if (!vectorLength)
102 return false;
103 return isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
104 isConstantIntValue(stripIndexCasts(vectorLength), 1);
105}
106
/// Return true when the compute construct `op` is "effectively serial": it
/// specifies num_gangs(1), num_workers(1), and vector_length(1). Gang, worker,
/// and vector are the only parallelism dimensions a user can express under the
/// OpenACC specification, so fixing all three to 1 matches how `serial`
/// semantics are defined.
template <typename ComputeOpT>
static bool isEffectivelySerial(ComputeOpT op) {
  const bool allWidthsAreOne = isGangWorkerVectorAllOne(op);
  return allWidthsAreOne;
}
115
116static bool isOpInComputeRegion(Operation *op) {
117 Region *region = op->getBlock()->getParent();
118 return getEnclosingComputeOp(*region) != nullptr;
119}
120
/// Return true if `op` executes in a context whose parallelism is already
/// fixed to sequential. Ancestors are queried in the order below, and the first
/// query that finds an ancestor of that kind decides the answer:
///   - nearest acc.parallel / acc.kernels: serial iff isEffectivelySerial
///     (num_gangs, num_workers, vector_length all the constant 1);
///   - any enclosing acc.serial: always serial;
///   - nearest acc.compute_region: its own isEffectivelySerial();
///   - enclosing function that is a specialized acc routine: serial iff its
///     SpecializedRoutineAttr level is ParLevel::seq.
/// Returns false when none of these apply.
121static bool isOpInSerialRegion(Operation *op) {
122 if (auto parallelOp = op->getParentOfType<ParallelOp>())
123 return isEffectivelySerial(parallelOp);
124 if (auto kernelsOp = op->getParentOfType<KernelsOp>())
125 return isEffectivelySerial(kernelsOp);
126 if (op->getParentOfType<SerialOp>())
127 return true;
128 if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
129 return computeRegion.isEffectivelySerial();
130 if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
131 if (isSpecializedAccRoutine(funcOp)) {
// Only routines specialized at the `seq` level count as sequential; a
// missing attribute is treated as non-serial.
132 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
134 if (attr && attr.getLevel().getValue() == ParLevel::seq)
135 return true;
136 }
137 }
138 return false;
139}
140
141static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
142 op->setAttr(GPUParallelDimsAttr::name, attr);
143}
144
145/// Clone defining ops of constant live-in values into `region`, rewrite uses
146/// inside the region to the clones, and remove those values from
147/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
148static void materializeConstantLiveInsIntoRegion(Region &region,
149 SetVector<Value> &liveInValues,
150 RewriterBase &rewriter) {
151 SmallVector<Value> constantLiveIns;
152 for (Value v : liveInValues) {
153 Operation *defOp = v.getDefiningOp();
154 if (defOp && matchPattern(defOp, m_Constant())) {
155 // As per the definition of ConstantLike trait, constants must have a
156 // single result.
157 assert(defOp->getNumResults() == 1 &&
158 "constants must have a single result");
159 constantLiveIns.push_back(v);
160 }
161 }
162 if (constantLiveIns.empty())
163 return;
164
165 OpBuilder::InsertionGuard guard(rewriter);
166 rewriter.setInsertionPointToStart(&region.front());
167
168 for (Value v : constantLiveIns) {
169 Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
170 replaceAllUsesInRegionWith(v, newV, region);
171 liveInValues.remove(v);
172 }
173}
174
175/// Insert a parallel dimension into the list, maintaining order by
176/// GPUParallelDimAttr::getOrder (descending).
177static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
178 GPUParallelDimAttr parDim) {
179 GPUParallelDimAttr *lb = llvm::lower_bound(
180 parDims, parDim,
181 [](const GPUParallelDimAttr &a, const GPUParallelDimAttr &b) {
182 return a.getOrder() > b.getOrder();
183 });
184 if (lb == parDims.end() || *lb != parDim)
185 parDims.insert(lb, parDim);
186}
187
188/// Return the device type from which gang/worker/vector clauses should be read.
189/// If the requested device type has any such clauses, use that exclusively;
190/// otherwise fall back to the default (DeviceType::None).
191static DeviceType getGangWorkerVectorDeviceType(LoopOp loopOp,
192 DeviceType deviceType) {
193 if (deviceType != DeviceType::None &&
194 loopOp.hasAnyGangWorkerVector(deviceType))
195 return deviceType;
196 return DeviceType::None;
197}
198
199template <typename ComputeConstructT>
200static DeviceType getParDimsDeviceType(ComputeConstructT computeOp,
201 DeviceType deviceType) {
202 if (deviceType != DeviceType::None &&
203 computeOp.hasAnyGangWorkerVector(deviceType))
204 return deviceType;
205 return DeviceType::None;
206}
207
208/// Map loop parallelism clauses (gang/worker/vector) to GPU parallel
209/// dimensions using the given mapping policy.
///
/// Clauses are read from the device type chosen by
/// getGangWorkerVectorDeviceType. Each dimension is added through insertParDim,
/// which keeps the list ordered by descending GPUParallelDimAttr order and
/// skips duplicates. A gang clause with a `dim` operand maps to the matching
/// gang level; a gang clause without one maps to gang_dim1.
/// NOTE(review): if the gang `dim` operand is not defined by an
/// arith::ConstantIntOp (e.g. an index-typed constant or one behind an
/// index_cast), no gang dimension is recorded at all — confirm the verifier
/// guarantees a ConstantIntOp here.
211getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
212 DeviceType deviceType) {
213 deviceType = getGangWorkerVectorDeviceType(loopOp, deviceType);
215 auto *ctx = loopOp->getContext();
216
217 if (loopOp.hasVector(deviceType))
218 insertParDim(parDims, policy.vectorDim(ctx));
219 if (loopOp.hasWorker(deviceType))
220 insertParDim(parDims, policy.workerDim(ctx));
// An explicit gang dimension takes precedence over the plain gang clause.
221 if (auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
222 if (auto gangDimDefOp =
223 gangDimValue.getDefiningOp<arith::ConstantIntOp>()) {
224 auto gangLevel = getGangParLevel(gangDimDefOp.value());
225 insertParDim(parDims, policy.gangDim(ctx, gangLevel));
226 }
227 } else if (loopOp.hasGang(deviceType)) {
228 insertParDim(parDims, policy.gangDim(ctx, ParLevel::gang_dim1));
229 }
230 return parDims;
231}
232
233/// Build `acc.compute_region` launch operands: one sequential `acc.par_width`
234/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs
235/// operand and num_workers / vector_length are the constant 1, and otherwise
236/// `acc.par_width` from gang/worker/vector (device-type operands first, then
237/// default DeviceType::None).
///
/// Each width operand is cast to `index` (via getValueOrCreateCastToIndexLike)
/// before being wrapped in an acc.par_width. The i-th num_gangs operand maps to
/// gang level i+1.
/// NOTE(review): the effectively-serial shortcut calls isEffectivelySerial,
/// which inspects the accessors without a device type, i.e. before
/// `deviceType` is resolved — confirm that a device-specific width overriding
/// a default of 1 is meant to be ignored here.
/// NOTE(review): getGangParLevel is documented for dims 1-3; this assumes at
/// most three num_gangs operands — confirm the op verifier enforces that.
238template <typename ComputeConstructT>
240assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
241 RewriterBase &rewriter,
242 const ACCToGPUMappingPolicy &policy) {
243 auto *ctx = rewriter.getContext();
244 auto loc = computeOp->getLoc();
245
// acc.serial always launches with a single sequential width.
246 if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
247 return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
248 } else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
249 KernelsOp>::value) {
250 if (isEffectivelySerial(computeOp))
251 return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
252
// Use the device-specific clauses if present, otherwise the defaults.
253 deviceType = getParDimsDeviceType(computeOp, deviceType);
254
255 SmallVector<Value> values;
256 auto indexTy = rewriter.getIndexType();
257
258 auto numGangs = computeOp.getNumGangsValues(deviceType);
259 for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
260 auto gangLevel = getGangParLevel(gangDimIdx + 1);
261 values.push_back(ParWidthOp::create(
262 rewriter, loc,
263 getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
264 gangSize),
265 policy.gangDim(ctx, gangLevel)));
266 }
267
268 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
269 if (numWorkers) {
270 values.push_back(ParWidthOp::create(
271 rewriter, loc,
272 getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
273 indexTy, numWorkers),
274 policy.workerDim(ctx)));
275 }
276
277 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
278 if (vectorLength) {
279 values.push_back(ParWidthOp::create(
280 rewriter, loc,
281 getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
282 indexTy, vectorLength),
283 policy.vectorDim(ctx)));
284 }
285 return values;
286 } else {
287 llvm_unreachable("assignKnownLaunchArgs: expected parallel, kernels, or "
288 "serial");
289 }
290}
291
292//===----------------------------------------------------------------------===//
293// Loop conversion pattern
294//===----------------------------------------------------------------------===//
295
/// Rewrite pattern converting acc.loop into SCF form. The strategy is chosen
/// in this order:
///   1. unstructured loop            -> scf.execute_region;
///   2. seq mode or serial context   -> scf.parallel with sequential par_dims;
///   3. auto mode                    -> scf.for (collapse enabled) with any
///                                      gang/worker/vector par_dims attached;
///   4. orphan loop (not in a compute region and, per the condition on the
///      next branch, not in a specialized routine) -> plain scf.for;
///   5. otherwise (independent)      -> scf.parallel with par_dims.
/// Returns failure if the selected conversion helper returns a null op.
/// `policy` is stored by reference; the caller must keep it alive while the
/// pattern is in use.
296class ACCLoopConversion : public OpRewritePattern<LoopOp> {
297public:
298 ACCLoopConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
299 DeviceType deviceType)
300 : OpRewritePattern<LoopOp>(ctx), policy(policy), deviceType(deviceType) {}
301
302 LogicalResult matchAndRewrite(LoopOp loopOp,
303 PatternRewriter &rewriter) const override {
// Unstructured loops keep their CFG body inside an scf.execute_region.
304 if (loopOp.getUnstructured()) {
305 auto executeRegion =
307 if (!executeRegion)
308 return failure();
309 rewriter.replaceOp(loopOp, executeRegion);
310 return success();
311 }
312
313 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
314
315 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
316 // Although it might seem unintuitive, scf.parallel is used here because
317 // the parallelism of the loop is already predetermined (as sequential).
318 // scf.for will become a candidate for auto-parallelization analysis.
319 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
320 if (!parallelOp)
321 return failure();
322 setParDimsAttr(parallelOp,
323 GPUParallelDimsAttr::seq(loopOp->getContext()));
324 rewriter.replaceOp(loopOp, parallelOp);
325 } else if (parMode == LoopParMode::loop_auto) {
326 // All loops in serial regions should have already been handled.
327 assert(!isOpInSerialRegion(loopOp) &&
328 "Expected loop to be in non-serial region");
329 // Mark as scf.for to allow auto-parallelization analysis later.
330 auto forOp =
331 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/true);
332 if (!forOp)
333 return failure();
335 getParallelDimensions(loopOp, policy, deviceType);
// Only attach par_dims when at least one gang/worker/vector clause mapped.
336 if (!parDims.empty()) {
337 auto parDimsAttr =
338 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
339 setParDimsAttr(forOp, parDimsAttr);
340 }
341 rewriter.replaceOp(loopOp, forOp);
342 } else if (!isOpInComputeRegion(loopOp) &&
344 loopOp->getParentOfType<FunctionOpInterface>())) {
345 // This loop is an orphan `acc loop` but it is not in any sort
346 // of compute region. Thus it is just a sequential non-accelerator loop.
347 auto forOp =
348 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/false);
349 if (!forOp)
350 return failure();
351 rewriter.replaceOp(loopOp, forOp);
352 } else {
353 assert(parMode == LoopParMode::loop_independent &&
354 "Expected loop to be independent");
355 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
356 if (!parallelOp)
357 return failure();
358
360 getParallelDimensions(loopOp, policy, deviceType);
361 if (!parDims.empty()) {
362 auto parDimsAttr =
363 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
364 setParDimsAttr(parallelOp, parDimsAttr);
365 }
366
367 rewriter.replaceOp(loopOp, parallelOp);
368 }
369 return success();
370 }
371
372private:
373 const ACCToGPUMappingPolicy &policy;
374 DeviceType deviceType;
375};
376
377//===----------------------------------------------------------------------===//
378// Compute construct conversion pattern
379//===----------------------------------------------------------------------===//
380
381template <typename ComputeConstructT>
382class ComputeOpConversion : public OpRewritePattern<ComputeConstructT> {
383public:
384 ComputeOpConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
385 DeviceType deviceType)
386 : OpRewritePattern<ComputeConstructT>(ctx), policy(policy),
387 deviceType(deviceType) {}
388
389 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
390 PatternRewriter &rewriter) const override {
391 rewriter.setInsertionPoint(computeOp);
392 auto kernelEnv =
393 KernelEnvironmentOp::createAndPopulate(computeOp, deviceType, rewriter);
394 auto launchArgs =
395 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
396 Region &region = computeOp.getRegion();
397 SetVector<Value> liveInValues;
398 getUsedValuesDefinedAbove(region, region, liveInValues);
399 materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
400 IRMapping mapping;
401 auto computeRegion = buildComputeRegion(
402 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
403 ComputeConstructT::getOperationName(), region, rewriter, mapping);
404 if (!computeRegion) {
405 rewriter.eraseOp(kernelEnv);
406 return failure();
407 }
408 rewriter.eraseOp(computeOp);
409 return success();
410 }
411
412private:
413 const ACCToGPUMappingPolicy &policy;
414 DeviceType deviceType;
415};
416
417//===----------------------------------------------------------------------===//
418// Pass implementation
419//===----------------------------------------------------------------------===//
420
/// Pass lowering OpenACC compute constructs in two greedy phases: first acc.loop
/// ops (their conversion strategy depends on the still-present enclosing
/// compute construct), then acc.parallel / acc.kernels / acc.serial into
/// acc.kernel_environment { acc.compute_region { ... } }. Both phases share the
/// same mapping `policy` and `deviceType`. If applyPatternsGreedily reports
/// failure for either phase, the pass fails.
/// NOTE(review): `deviceType` is not declared in this class; presumably it is a
/// pass option from the generated ACCComputeLoweringBase — confirm.
421class ACCComputeLowering
422 : public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
423public:
424 using ACCComputeLoweringBase::ACCComputeLoweringBase;
425
426 void runOnOperation() override {
427 auto op = getOperation();
428 auto *context = op.getContext();
429
431
432 // Part 1: Convert acc.loop to scf.parallel/scf.for while the parent
433 // compute construct is still present (needed to determine conversion
434 // strategy).
435 RewritePatternSet loopPatterns(context);
436 loopPatterns.insert<ACCLoopConversion>(context, policy, deviceType);
437 if (failed(applyPatternsGreedily(op, std::move(loopPatterns))))
438 return signalPassFailure();
439
440 // Part 2: Convert acc.parallel, acc.kernels, and acc.serial to
441 // acc.kernel_environment { acc.compute_region { ... } }.
442 RewritePatternSet computePatterns(context);
443 computePatterns
444 .insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
445 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
446 if (failed(applyPatternsGreedily(op, std::move(computePatterns))))
447 return signalPassFailure();
448 }
449};
450
451} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
ParDimAttrT seqDim(MLIRContext *ctx) const
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:55
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
Definition OpenACC.h:189
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
Definition OpenACC.h:201
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...