ACCComputeLowering.cpp
//===- ACCComputeLowering.cpp - Lower ACC compute to compute_region -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass decomposes OpenACC compute constructs into a representation that
// separates the data environment from the compute portion and prepares for
// parallelism assignment and privatization at the appropriate level.
//
// Overview:
// ---------
// Each compute construct (`acc.parallel`, `acc.serial`, `acc.kernels`) is
// lowered to (1) `acc.kernel_environment`, which captures the data
// environment, and (2) `acc.compute_region`, which holds the compute body.
// Inside the compute region, acc.loop is converted to SCF loops
// (`scf.parallel` or `scf.for`) with any predetermined parallelism expressed
// as `par_dims`. This decomposition allows later phases to assign parallelism
// and handle privatization at the right granularity. A rough sketch of the
// resulting structure follows this header.
//
// Transformations:
// ----------------
// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
//    replaced by acc.kernel_environment containing a single
//    acc.compute_region. For acc.parallel / acc.kernels, launch arguments
//    (num_gangs, num_workers, vector_length) become acc.par_width ops (each
//    result is `index`) and are passed as compute_region launch operands. For
//    acc.serial, and for compute constructs with num_gangs(1), num_workers(1),
//    and vector_length(1), a single sequential acc.par_width launch operand is
//    used.
//
// 2. acc.loop: Converted according to context and attributes:
//    - Unstructured: body wrapped in scf.execute_region.
//    - Sequential (serial region, seq clause, or compute construct with
//      num_gangs(1), num_workers(1), and vector_length(1)):
//      scf.parallel with par_dims = sequential.
//    - Auto (in parallel/kernels): scf.for with collapse when
//      multi-dimensional.
//    - Orphan (not inside a compute construct): scf.for, no collapse.
//    - Independent (in parallel/kernels): scf.parallel with par_dims from
//      gang/worker/vector mapping (e.g. block_x).
//
//===----------------------------------------------------------------------===//
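//
// For illustration only, a sketch of the shape of the result. The syntax below
// is abbreviated and is not the exact assembly format of these operations;
// operand types, attributes, data clauses, and the exact nesting details are
// elided:
//
//   acc.parallel num_gangs(%ng) num_workers(%nw) {
//     acc.loop ... { ... }   // independent, gang
//   }
//
// becomes approximately:
//
//   acc.kernel_environment ... {
//     %g = acc.par_width(%ng cast to index) ...   // gang launch width
//     %w = acc.par_width(%nw cast to index) ...   // worker launch width
//     acc.compute_region launch(%g, %w) ... {
//       scf.parallel ... { ... }   // annotated with par_dims (e.g. block_x)
//     }
//   }
//
//===----------------------------------------------------------------------===//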

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace acc {
#define GEN_PASS_DEF_ACCCOMPUTELOWERING
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir

#define DEBUG_TYPE "acc-compute-lowering"

using namespace mlir;
using namespace mlir::acc;

namespace {

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// Strip index_cast operations from a value before checking for a constant.
static Value stripIndexCasts(Value val) {
  while (auto castOp = val.getDefiningOp<arith::IndexCastOp>())
    val = castOp.getIn();
  return val;
}

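/// Return true when num_gangs, num_workers, and vector_length are all present
/// and each evaluates to the constant 1 (looking through index casts). Absent
/// clauses yield false, since the implementation is then free to pick a
/// larger width.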
template <typename ComputeOpT>
static bool isGangWorkerVectorAllOne(ComputeOpT op) {
  auto numGangs = op.getNumGangsValues();
  if (numGangs.empty())
    return false;
  for (Value gangSize : numGangs) {
    if (!isConstantIntValue(stripIndexCasts(gangSize), 1))
      return false;
  }
  Value numWorkers = op.getNumWorkersValue();
  if (!numWorkers)
    return false;
  Value vectorLength = op.getVectorLengthValue();
  if (!vectorLength)
    return false;
  return isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
         isConstantIntValue(stripIndexCasts(vectorLength), 1);
}

/// A compute construct is "effectively serial" when it specifies
/// num_gangs(1), num_workers(1), and vector_length(1). These are the only
/// parallelism dimensions expressible at the OpenACC level, so constraining
/// all of them to 1 is consistent with how `serial` semantics are defined.
template <typename ComputeOpT>
static bool isEffectivelySerial(ComputeOpT op) {
  return isGangWorkerVectorAllOne(op);
}

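/// Return true if `op` is nested inside an OpenACC compute construct.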
static bool isOpInComputeRegion(Operation *op) {
  Region *region = op->getBlock()->getParent();
  return getEnclosingComputeOp(*region) != nullptr;
}

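/// Return true if `op` executes in a serial context: an enclosing
/// acc.parallel / acc.kernels that is effectively serial, an enclosing
/// acc.serial, an already-lowered acc.compute_region that is effectively
/// serial, or a specialized `acc routine` function at the `seq` level.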
static bool isOpInSerialRegion(Operation *op) {
  if (auto parallelOp = op->getParentOfType<ParallelOp>())
    return isEffectivelySerial(parallelOp);
  if (auto kernelsOp = op->getParentOfType<KernelsOp>())
    return isEffectivelySerial(kernelsOp);
  if (op->getParentOfType<SerialOp>())
    return true;
  if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
    return computeRegion.isEffectivelySerial();
  if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
    if (isSpecializedAccRoutine(funcOp)) {
      auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
          getSpecializedRoutineAttrName());
      if (attr && attr.getLevel().getValue() == ParLevel::seq)
        return true;
    }
  }
  return false;
}

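/// Attach `attr` to `op` under the canonical GPU parallel-dims attribute name.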
static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
  op->setAttr(GPUParallelDimsAttr::name, attr);
}

/// Clone defining ops of constant live-in values into `region`, rewrite uses
/// inside the region to the clones, and remove those values from
/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
static void materializeConstantLiveInsIntoRegion(Region &region,
                                                 SetVector<Value> &liveInValues,
                                                 RewriterBase &rewriter) {
  SmallVector<Value> constantLiveIns;
  for (Value v : liveInValues) {
    Operation *defOp = v.getDefiningOp();
    if (defOp && matchPattern(defOp, m_Constant())) {
      // As per the definition of the ConstantLike trait, constants must have a
      // single result.
      assert(defOp->getNumResults() == 1 &&
             "constants must have a single result");
      constantLiveIns.push_back(v);
    }
  }
  if (constantLiveIns.empty())
    return;

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&region.front());

  for (Value v : constantLiveIns) {
    Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
    replaceAllUsesInRegionWith(v, newV, region);
    liveInValues.remove(v);
  }
}

/// Insert a parallel dimension into the list, maintaining order by
/// GPUParallelDimAttr::getOrder (descending). Duplicates are not inserted.
static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
                         GPUParallelDimAttr parDim) {
  GPUParallelDimAttr *lb = llvm::lower_bound(
      parDims, parDim,
      [](const GPUParallelDimAttr &a, const GPUParallelDimAttr &b) {
        return a.getOrder() > b.getOrder();
      });
  if (lb == parDims.end() || *lb != parDim)
    parDims.insert(lb, parDim);
}

/// Map loop parallelism clauses (gang/worker/vector) to GPU parallel
/// dimensions using the given mapping policy.
static SmallVector<GPUParallelDimAttr>
getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
                      DeviceType deviceType) {
  SmallVector<GPUParallelDimAttr> parDims;
  auto *ctx = loopOp->getContext();

  if (loopOp.hasVector(deviceType))
    insertParDim(parDims, policy.vectorDim(ctx));
  if (loopOp.hasWorker(deviceType))
    insertParDim(parDims, policy.workerDim(ctx));
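  // gang(dim: N): when N is a compile-time constant, map to the matching gang
  // dimension; a bare gang clause defaults to the first gang dimension.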
  if (auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
    if (auto gangDimDefOp =
            gangDimValue.getDefiningOp<arith::ConstantIntOp>()) {
      auto gangLevel = getGangParLevel(gangDimDefOp.value());
      insertParDim(parDims, policy.gangDim(ctx, gangLevel));
    }
  } else if (loopOp.hasGang(deviceType)) {
    insertParDim(parDims, policy.gangDim(ctx, ParLevel::gang_dim1));
  }
  return parDims;
}

/// Build the `acc.compute_region` launch operands: a single sequential
/// `acc.par_width` for `acc.serial` and for `acc.parallel` / `acc.kernels`
/// whose num_gangs, num_workers, and vector_length operands are all the
/// constant 1; otherwise one `acc.par_width` per gang/worker/vector clause
/// (operands for the requested device type are preferred, falling back to the
/// default DeviceType::None operands).
template <typename ComputeConstructT>
static SmallVector<Value>
assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
                      RewriterBase &rewriter,
                      const ACCToGPUMappingPolicy &policy) {
  auto *ctx = rewriter.getContext();
  auto loc = computeOp->getLoc();

  if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
    return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
  } else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
                                       KernelsOp>::value) {
    if (isEffectivelySerial(computeOp))
      return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};

    SmallVector<Value> values;
    auto indexTy = rewriter.getIndexType();

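    // Prefer clause operands attached to the requested device type; fall back
    // to the default (DeviceType::None) operands when none are present.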
    auto numGangs = computeOp.getNumGangsValues(deviceType);
    if (numGangs.empty())
      numGangs = computeOp.getNumGangsValues();
    for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
      auto gangLevel = getGangParLevel(gangDimIdx + 1);
      values.push_back(ParWidthOp::create(
          rewriter, loc,
          getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
                                          gangSize),
          policy.gangDim(ctx, gangLevel)));
    }

    Value numWorkers = computeOp.getNumWorkersValue(deviceType);
    if (!numWorkers)
      numWorkers = computeOp.getNumWorkersValue();
    if (numWorkers) {
      values.push_back(ParWidthOp::create(
          rewriter, loc,
          getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
                                          indexTy, numWorkers),
          policy.workerDim(ctx)));
    }

    Value vectorLength = computeOp.getVectorLengthValue(deviceType);
    if (!vectorLength)
      vectorLength = computeOp.getVectorLengthValue();
    if (vectorLength) {
      values.push_back(ParWidthOp::create(
          rewriter, loc,
          getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
                                          indexTy, vectorLength),
          policy.vectorDim(ctx)));
    }
    return values;
  } else {
    llvm_unreachable("assignKnownLaunchArgs: expected parallel, kernels, or "
                     "serial");
  }
}

//===----------------------------------------------------------------------===//
// Loop conversion pattern
//===----------------------------------------------------------------------===//

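/// Convert `acc.loop` into SCF loops. The strategy depends on the surrounding
/// construct and the loop's parallelism mode; see the overview at the top of
/// the file.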
class ACCLoopConversion : public OpRewritePattern<LoopOp> {
public:
  ACCLoopConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
                    DeviceType deviceType)
      : OpRewritePattern<LoopOp>(ctx), policy(policy), deviceType(deviceType) {}

  LogicalResult matchAndRewrite(LoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    if (loopOp.getUnstructured()) {
      auto executeRegion =
          convertUnstructuredACCLoopToSCFExecuteRegion(loopOp, rewriter);
      if (!executeRegion)
        return failure();
      rewriter.replaceOp(loopOp, executeRegion);
      return success();
    }

    LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);

    if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
      // Although it might seem unintuitive, scf.parallel is used here because
      // the parallelism of the loop is already predetermined (as sequential),
      // whereas scf.for would become a candidate for auto-parallelization
      // analysis.
      auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
      if (!parallelOp)
        return failure();
      setParDimsAttr(parallelOp,
                     GPUParallelDimsAttr::seq(loopOp->getContext()));
      rewriter.replaceOp(loopOp, parallelOp);
    } else if (parMode == LoopParMode::loop_auto) {
      // All loops in serial regions should have already been handled.
      assert(!isOpInSerialRegion(loopOp) &&
             "Expected loop to be in non-serial region");
      // Lower to scf.for to allow auto-parallelization analysis later.
      auto forOp =
          convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/true);
      if (!forOp)
        return failure();
      rewriter.replaceOp(loopOp, forOp);
    } else if (!isOpInComputeRegion(loopOp) &&
               !isSpecializedAccRoutine(
                   loopOp->getParentOfType<FunctionOpInterface>())) {
      // This is an orphan `acc loop`: it is not inside any compute construct
      // (nor inside a specialized acc routine), so it is simply a sequential
      // non-accelerator loop.
      auto forOp =
          convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/false);
      if (!forOp)
        return failure();
      rewriter.replaceOp(loopOp, forOp);
    } else {
      assert(parMode == LoopParMode::loop_independent &&
             "Expected loop to be independent");
      auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
      if (!parallelOp)
        return failure();

      SmallVector<GPUParallelDimAttr> parDims =
          getParallelDimensions(loopOp, policy, deviceType);
      if (!parDims.empty()) {
        auto parDimsAttr =
            GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
        setParDimsAttr(parallelOp, parDimsAttr);
      }

      rewriter.replaceOp(loopOp, parallelOp);
    }
    return success();
  }

private:
  const ACCToGPUMappingPolicy &policy;
  DeviceType deviceType;
};

//===----------------------------------------------------------------------===//
// Compute construct conversion pattern
//===----------------------------------------------------------------------===//

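/// Replace a compute construct (`acc.parallel`, `acc.kernels`, `acc.serial`)
/// with an `acc.kernel_environment` capturing its data environment and an
/// `acc.compute_region` that receives the launch operands, the live-in
/// values, and a clone of the construct's body.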
template <typename ComputeConstructT>
class ComputeOpConversion : public OpRewritePattern<ComputeConstructT> {
public:
  ComputeOpConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
                      DeviceType deviceType)
      : OpRewritePattern<ComputeConstructT>(ctx), policy(policy),
        deviceType(deviceType) {}

  LogicalResult matchAndRewrite(ComputeConstructT computeOp,
                                PatternRewriter &rewriter) const override {
    rewriter.setInsertionPoint(computeOp);
    auto kernelEnv =
        KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
    auto launchArgs =
        assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
    Region &region = computeOp.getRegion();
    SetVector<Value> liveInValues;
    getUsedValuesDefinedAbove(region, region, liveInValues);
    materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
    IRMapping mapping;
    auto computeRegion = buildComputeRegion(
        computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
        ComputeConstructT::getOperationName(), region, rewriter, mapping);
    if (!computeRegion) {
      rewriter.eraseOp(kernelEnv);
      return failure();
    }
    rewriter.eraseOp(computeOp);
    return success();
  }

private:
  const ACCToGPUMappingPolicy &policy;
  DeviceType deviceType;
};

//===----------------------------------------------------------------------===//
// Pass implementation
//===----------------------------------------------------------------------===//

class ACCComputeLowering
    : public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
public:
  using ACCComputeLoweringBase::ACCComputeLoweringBase;

  void runOnOperation() override {
    auto op = getOperation();
    auto *context = op.getContext();
    ACCToGPUMappingPolicy policy;

    // Part 1: Convert acc.loop to scf.parallel/scf.for while the parent
    // compute construct is still present (needed to determine the conversion
    // strategy).
    RewritePatternSet loopPatterns(context);
    loopPatterns.insert<ACCLoopConversion>(context, policy, deviceType);
    if (failed(applyPatternsGreedily(op, std::move(loopPatterns))))
      return signalPassFailure();

    // Part 2: Convert acc.parallel, acc.kernels, and acc.serial to
    // acc.kernel_environment { acc.compute_region { ... } }.
    RewritePatternSet computePatterns(context);
    computePatterns
        .insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
                ComputeOpConversion<SerialOp>>(context, policy, deviceType);
    if (failed(applyPatternsGreedily(op, std::move(computePatterns))))
      return signalPassFailure();
  }
};

} // namespace