MLIR 23.0.0git
ACCComputeLowering.cpp
Go to the documentation of this file.
1//===- ACCComputeLowering.cpp - Lower ACC compute to compute_region -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass decomposes OpenACC compute constructs into a representation that
10// separates the data environment from the compute portion and prepares for
11// parallelism assignment and privatization at the appropriate level.
12//
13// Overview:
14// ---------
15// Each compute construct (`acc.parallel`, `acc.serial`, `acc.kernels`) is
16// lowered to (1) `acc.kernel_environment`, which captures the data environment
17// and (2) `acc.compute_region`, which holds the compute body. Inside the
18// compute region, acc.loop is converted to SCF loops (`scf.parallel` or
19// `scf.for`) with any predetermined parallelism expressed as `par_dims`. This
20// decomposition allows later phases to assign parallelism and handle
21// privatization at the right granularity.
22//
23// Transformations:
24// ----------------
25// 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
26// replaced by acc.kernel_environment containing a single acc.compute_region.
27// Launch arguments (num_gangs, num_workers, vector_length) become
28// acc.par_width ops (each result is `index`) and are passed as
29// compute_region launch operands (still required to be acc.par_width
30// results by the compute_region verifier).
31//
32// 2. acc.loop: Converted according to context and attributes:
33// - Unstructured: body wrapped in scf.execute_region.
34// - Sequential (serial region or seq clause): scf.parallel with
35// par_dims = sequential.
36// - Auto (in parallel/kernels): scf.for with collapse when
37// multi-dimensional.
38// - Orphan (not inside a compute construct): scf.for, no collapse.
39// - Independent (in parallel/kernels): scf.parallel with par_dims from
40// gang/worker/vector mapping (e.g. block_x).
41//
42//===----------------------------------------------------------------------===//
43
45
54#include "mlir/IR/IRMapping.h"
55#include "mlir/IR/Matchers.h"
59
60namespace mlir {
61namespace acc {
62#define GEN_PASS_DEF_ACCCOMPUTELOWERING
63#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
64} // namespace acc
65} // namespace mlir
66
67#define DEBUG_TYPE "acc-compute-lowering"
68
69using namespace mlir;
70using namespace mlir::acc;
71
72namespace {
73
74//===----------------------------------------------------------------------===//
75// Helper functions
76//===----------------------------------------------------------------------===//
77
78/// Strip index_cast operations from a value before checking for a constant.
79static Value stripIndexCasts(Value val) {
80 while (auto castOp = val.getDefiningOp<arith::IndexCastOp>())
81 val = castOp.getIn();
82 return val;
83}
84
85/// A parallel construct is "effectively serial" when it specifies
86/// num_gangs(1), num_workers(1), and vector_length(1). This matches
87/// the semantics of acc.serial but expressed through acc.parallel.
88static bool isEffectivelySerial(ParallelOp op) {
89 auto numGangs = op.getNumGangsValues();
90 if (numGangs.size() != 1)
91 return false;
92 Value numWorkers = op.getNumWorkersValue();
93 if (!numWorkers)
94 return false;
95 Value vectorLength = op.getVectorLengthValue();
96 if (!vectorLength)
97 return false;
98 return isConstantIntValue(stripIndexCasts(numGangs.front()), 1) &&
99 isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
100 isConstantIntValue(stripIndexCasts(vectorLength), 1);
101}
102
103static bool isOpInComputeRegion(Operation *op) {
104 Region *region = op->getBlock()->getParent();
105 return getEnclosingComputeOp(*region) != nullptr;
106}
107
108static bool isOpInSerialRegion(Operation *op) {
109 if (auto parallelOp = op->getParentOfType<ParallelOp>())
110 return isEffectivelySerial(parallelOp);
111 if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
112 return computeRegion.isEffectivelySerial();
113 if (op->getParentOfType<SerialOp>())
114 return true;
115 if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
116 if (isSpecializedAccRoutine(funcOp)) {
117 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
119 if (attr && attr.getLevel().getValue() == ParLevel::seq)
120 return true;
121 }
122 }
123 return false;
124}
125
/// Attach `attr` to `op` under the canonical GPUParallelDimsAttr name,
/// recording the loop's predetermined parallelism dimensions.
static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
  op->setAttr(GPUParallelDimsAttr::name, attr);
}
129
130/// Clone defining ops of constant live-in values into `region`, rewrite uses
131/// inside the region to the clones, and remove those values from
132/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
133static void materializeConstantLiveInsIntoRegion(Region &region,
134 SetVector<Value> &liveInValues,
135 RewriterBase &rewriter) {
136 SmallVector<Value> constantLiveIns;
137 for (Value v : liveInValues) {
138 Operation *defOp = v.getDefiningOp();
139 if (defOp && matchPattern(defOp, m_Constant())) {
140 // As per the definition of ConstantLike trait, constants must have a
141 // single result.
142 assert(defOp->getNumResults() == 1 &&
143 "constants must have a single result");
144 constantLiveIns.push_back(v);
145 }
146 }
147 if (constantLiveIns.empty())
148 return;
149
150 OpBuilder::InsertionGuard guard(rewriter);
151 rewriter.setInsertionPointToStart(&region.front());
152
153 for (Value v : constantLiveIns) {
154 Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
155 replaceAllUsesInRegionWith(v, newV, region);
156 liveInValues.remove(v);
157 }
158}
159
160/// Insert a parallel dimension into the list, maintaining order by
161/// GPUParallelDimAttr::getOrder (descending).
162static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
163 GPUParallelDimAttr parDim) {
164 GPUParallelDimAttr *lb = llvm::lower_bound(
165 parDims, parDim,
166 [](const GPUParallelDimAttr &a, const GPUParallelDimAttr &b) {
167 return a.getOrder() > b.getOrder();
168 });
169 if (lb == parDims.end() || *lb != parDim)
170 parDims.insert(lb, parDim);
171}
172
173/// Map loop parallelism clauses (gang/worker/vector) to GPU parallel
174/// dimensions using the given mapping policy.
176getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
177 DeviceType deviceType) {
179 auto *ctx = loopOp->getContext();
180
181 if (loopOp.hasVector(deviceType))
182 insertParDim(parDims, policy.vectorDim(ctx));
183 if (loopOp.hasWorker(deviceType))
184 insertParDim(parDims, policy.workerDim(ctx));
185 if (auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
186 if (auto gangDimDefOp =
187 gangDimValue.getDefiningOp<arith::ConstantIntOp>()) {
188 auto gangLevel = getGangParLevel(gangDimDefOp.value());
189 insertParDim(parDims, policy.gangDim(ctx, gangLevel));
190 }
191 } else if (loopOp.hasGang(deviceType)) {
192 insertParDim(parDims, policy.gangDim(ctx, ParLevel::gang_dim1));
193 }
194 return parDims;
195}
196
197/// Create acc.par_width operations from gang/worker/vector values of a
198/// compute construct. Queries the device-type-specific values first, falling
199/// back to the default (DeviceType::None) values.
200template <typename ComputeConstructT>
202assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
203 RewriterBase &rewriter,
204 const ACCToGPUMappingPolicy &policy) {
205 SmallVector<Value> values;
206 auto *ctx = rewriter.getContext();
207 auto indexTy = rewriter.getIndexType();
208 auto loc = computeOp->getLoc();
209
210 auto numGangs = computeOp.getNumGangsValues(deviceType);
211 if (numGangs.empty())
212 numGangs = computeOp.getNumGangsValues();
213 for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
214 auto gangLevel = getGangParLevel(gangDimIdx + 1);
215 values.push_back(
216 ParWidthOp::create(rewriter, loc,
218 rewriter, gangSize.getLoc(), indexTy, gangSize),
219 policy.gangDim(ctx, gangLevel)));
220 }
221
222 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
223 if (!numWorkers)
224 numWorkers = computeOp.getNumWorkersValue();
225 if (numWorkers) {
226 values.push_back(ParWidthOp::create(
227 rewriter, loc,
228 getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
229 numWorkers),
230 policy.workerDim(ctx)));
231 }
232
233 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
234 if (!vectorLength)
235 vectorLength = computeOp.getVectorLengthValue();
236 if (vectorLength) {
237 values.push_back(ParWidthOp::create(
238 rewriter, loc,
239 getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
240 indexTy, vectorLength),
241 policy.vectorDim(ctx)));
242 }
243 return values;
244}
245
246/// SerialOp has no gang/worker/vector clauses.
247template <>
249assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType, RewriterBase &,
250 const ACCToGPUMappingPolicy &) {
251 return {};
252}
253
254//===----------------------------------------------------------------------===//
255// Loop conversion pattern
256//===----------------------------------------------------------------------===//
257
258class ACCLoopConversion : public OpRewritePattern<LoopOp> {
259public:
260 ACCLoopConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
261 DeviceType deviceType)
262 : OpRewritePattern<LoopOp>(ctx), policy(policy), deviceType(deviceType) {}
263
264 LogicalResult matchAndRewrite(LoopOp loopOp,
265 PatternRewriter &rewriter) const override {
266 if (loopOp.getUnstructured()) {
267 auto executeRegion =
269 if (!executeRegion)
270 return failure();
271 rewriter.replaceOp(loopOp, executeRegion);
272 return success();
273 }
274
275 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
276
277 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
278 // Although it might seem unintuitive, scf.parallel is used here because
279 // the parallelism of the loop is already predetermined (as sequential).
280 // scf.for will become a candidate for auto-parallelization analysis.
281 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
282 if (!parallelOp)
283 return failure();
284 setParDimsAttr(parallelOp,
285 GPUParallelDimsAttr::seq(loopOp->getContext()));
286 rewriter.replaceOp(loopOp, parallelOp);
287 } else if (parMode == LoopParMode::loop_auto) {
288 // All loops in serial regions should have already been handled.
289 assert(!isOpInSerialRegion(loopOp) &&
290 "Expected loop to be in non-serial region");
291 // Mark as scf.for to allow auto-parallelization analysis later.
292 auto forOp =
293 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/true);
294 if (!forOp)
295 return failure();
296 rewriter.replaceOp(loopOp, forOp);
297 } else if (!isOpInComputeRegion(loopOp) &&
299 loopOp->getParentOfType<FunctionOpInterface>())) {
300 // This loop is an orphan `acc loop` but it is not in any sort
301 // of compute region. Thus it is just a sequential non-accelerator loop.
302 auto forOp =
303 convertACCLoopToSCFFor(loopOp, rewriter, /*enableCollapse=*/false);
304 if (!forOp)
305 return failure();
306 rewriter.replaceOp(loopOp, forOp);
307 } else {
308 assert(parMode == LoopParMode::loop_independent &&
309 "Expected loop to be independent");
310 auto parallelOp = convertACCLoopToSCFParallel(loopOp, rewriter);
311 if (!parallelOp)
312 return failure();
313
315 getParallelDimensions(loopOp, policy, deviceType);
316 if (!parDims.empty()) {
317 auto parDimsAttr =
318 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
319 setParDimsAttr(parallelOp, parDimsAttr);
320 }
321
322 rewriter.replaceOp(loopOp, parallelOp);
323 }
324 return success();
325 }
326
327private:
328 const ACCToGPUMappingPolicy &policy;
329 DeviceType deviceType;
330};
331
332//===----------------------------------------------------------------------===//
333// Compute construct conversion pattern
334//===----------------------------------------------------------------------===//
335
336template <typename ComputeConstructT>
337class ComputeOpConversion : public OpRewritePattern<ComputeConstructT> {
338public:
339 ComputeOpConversion(MLIRContext *ctx, const ACCToGPUMappingPolicy &policy,
340 DeviceType deviceType)
341 : OpRewritePattern<ComputeConstructT>(ctx), policy(policy),
342 deviceType(deviceType) {}
343
344 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
345 PatternRewriter &rewriter) const override {
346 rewriter.setInsertionPoint(computeOp);
347 auto kernelEnv =
348 KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
349 auto launchArgs =
350 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
351 Region &region = computeOp.getRegion();
352 SetVector<Value> liveInValues;
353 getUsedValuesDefinedAbove(region, region, liveInValues);
354 materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
355 IRMapping mapping;
356 auto computeRegion = buildComputeRegion(
357 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
358 ComputeConstructT::getOperationName(), region, rewriter, mapping);
359 if (!computeRegion) {
360 rewriter.eraseOp(kernelEnv);
361 return failure();
362 }
363 rewriter.eraseOp(computeOp);
364 return success();
365 }
366
367private:
368 const ACCToGPUMappingPolicy &policy;
369 DeviceType deviceType;
370};
371
372//===----------------------------------------------------------------------===//
373// Pass implementation
374//===----------------------------------------------------------------------===//
375
376class ACCComputeLowering
377 : public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
378public:
379 using ACCComputeLoweringBase::ACCComputeLoweringBase;
380
381 void runOnOperation() override {
382 auto op = getOperation();
383 auto *context = op.getContext();
384
386
387 // Part 1: Convert acc.loop to scf.parallel/scf.for while the parent
388 // compute construct is still present (needed to determine conversion
389 // strategy).
390 RewritePatternSet loopPatterns(context);
391 loopPatterns.insert<ACCLoopConversion>(context, policy, deviceType);
392 if (failed(applyPatternsGreedily(op, std::move(loopPatterns))))
393 return signalPassFailure();
394
395 // Part 2: Convert acc.parallel, acc.kernels, and acc.serial to
396 // acc.kernel_environment { acc.compute_region { ... } }.
397 RewritePatternSet computePatterns(context);
398 computePatterns
399 .insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
400 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
401 if (failed(applyPatternsGreedily(op, std::move(computePatterns))))
402 return signalPassFailure();
403 }
404};
405
406} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
Definition OpenACC.h:186
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
Definition OpenACC.h:198
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...