MLIR 23.0.0git
ACCIfClauseLowering.cpp
Go to the documentation of this file.
1//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
10// `if` clauses using region specialization. It creates two execution paths:
11// device execution when the condition is true, host execution when false.
12//
13// Overview:
14// ---------
15// When an ACC compute construct has an `if` clause, the construct should only
16// execute on the device when the condition is true. If the condition is false,
17// the code should execute on the host instead. This pass transforms:
18//
//   acc.parallel if(%cond) { ... }
//
// Into:
//
//   scf.if %cond {
//     // Device path: clone data ops, compute construct without if, exit ops
//     acc.parallel { ... }
//   } else {
//     // Host path: original region body with ACC ops converted to host
//   }
29//
30// Transformations:
31// ----------------
32// For each compute construct with an `if` clause:
33//
// 1. Device Path (true branch):
//    - Clone data entry operations (acc.copyin, acc.create, etc.) as well as
//      the firstprivate, private, and reduction operations
//    - Clone the compute construct without the `if` clause
//    - Clone data exit operations (acc.copyout, acc.delete, etc.)
//
// 2. Host Path (false branch):
//    - Move the original region body to the else branch (a multi-block body
//      is wrapped in an scf.execute_region, since scf.if regions must have a
//      single block)
//    - Apply host fallback patterns to convert ACC ops to host equivalents
//
// 3. Cleanup:
//    - Replace uses of ACC variables with host variables in the else branch
//    - Erase the original compute construct and data operations
46//
47// Requirements:
48// -------------
// To use this pass in a pipeline, note the following:
//
// 1. Analysis Registration (Optional): If custom behavior is needed for
//    emitting not-yet-implemented messages for unsupported cases, the
//    pipeline should pre-register the `acc::OpenACCSupport` analysis.
//    Without pre-registration, the pass obtains the analysis itself through
//    `getAnalysis<OpenACCSupport>()`.
54//
55//===----------------------------------------------------------------------===//
56
58
64#include "mlir/IR/Builders.h"
65#include "mlir/IR/IRMapping.h"
68#include "llvm/ADT/STLExtras.h"
69#include "llvm/Support/Debug.h"
70
71namespace mlir {
72namespace acc {
73#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
74#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
75} // namespace acc
76} // namespace mlir
77
78#define DEBUG_TYPE "acc-if-clause-lowering"
79
80using namespace mlir;
81using namespace mlir::acc;
82
83namespace {
84
85class ACCIfClauseLowering
86 : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
87 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
88
89private:
90 OpenACCSupport *accSupport = nullptr;
91
92 void convertHostRegion(Operation *computeOp, Region &region);
93
94 template <typename OpTy>
95 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
96 SmallVector<Operation *> &eraseOps);
97
98public:
99 void runOnOperation() override;
100};
101
102void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
103 Region &region) {
104 // Only collect ACC dialect operations - other ops don't need conversion
106 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
107 if (isa<acc::OpenACCDialect>(op->getDialect()))
108 hostOps.push_back(op);
109 });
110
111 RewritePatternSet patterns(computeOp->getContext());
112 populateACCHostFallbackPatterns(patterns, *accSupport);
113
114 GreedyRewriteConfig config;
115 config.setUseTopDownTraversal(true);
117 if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
118 accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
119}
120
121// Template function to handle if condition conversion for ACC compute
122// constructs
123template <typename OpTy>
124void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
125 OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
126 Value ifCond = computeConstructOp.getIfCond();
127 if (!ifCond)
128 return;
129
130 IRRewriter rewriter(computeConstructOp);
131
132 LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
133 << " with if condition: " << computeConstructOp
134 << "\n");
135
136 // Collect data clause operations that need to be recreated in the if
137 // condition
138 SmallVector<Operation *> dataEntryOps;
139 SmallVector<Operation *> dataExitOps;
140 SmallVector<Operation *> firstprivateOps;
141 SmallVector<Operation *> privateOps;
142 SmallVector<Operation *> reductionOps;
143
144 // Collect data entry operations
145 for (Value operand : computeConstructOp.getDataClauseOperands())
146 if (Operation *defOp = operand.getDefiningOp())
147 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
148 dataEntryOps.push_back(defOp);
149
150 // Find corresponding exit operations for each entry operation.
151 // Iterate backwards through entry ops since exit ops appear in reverse order.
152 for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
153 for (Operation *user : dataEntryOp->getUsers())
154 if (isa<ACC_DATA_EXIT_OPS>(user))
155 dataExitOps.push_back(user);
156
157 // Collect firstprivate, private, and reduction operations
158 auto collectOps = [&](SmallVector<Operation *> &ops, OperandRange operands) {
159 for (Value operand : operands)
160 if (Operation *defOp = operand.getDefiningOp())
161 ops.push_back(defOp);
162 };
163 collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
164 collectOps(privateOps, computeConstructOp.getPrivateOperands());
165 collectOps(reductionOps, computeConstructOp.getReductionOperands());
166
167 // Create scf.if with device and host execution paths
168 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
169 TypeRange{}, ifCond, /*withElseRegion=*/true);
170
171 LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
172 << " data entry operations for device path\n");
173
174 // Device execution path (true branch)
175 Block &thenBlock = ifOp.getThenRegion().front();
176 rewriter.setInsertionPointToStart(&thenBlock);
177
178 // Clone data entry operations
179 SmallVector<Value> deviceDataOperands;
180 SmallVector<Value> firstprivateOperands;
181 SmallVector<Value> privateOperands;
182 SmallVector<Value> reductionOperands;
183
184 // Map the data entry and firstprivate ops for the cloned region
185 IRMapping deviceMapping;
186 auto cloneAndMapOps = [&](SmallVector<Operation *> &ops,
187 SmallVector<Value> &operands) {
188 for (Operation *op : ops) {
189 Operation *clonedOp = rewriter.clone(*op, deviceMapping);
190 operands.push_back(clonedOp->getResult(0));
191 deviceMapping.map(op->getResult(0), clonedOp->getResult(0));
192 }
193 };
194 cloneAndMapOps(dataEntryOps, deviceDataOperands);
195 cloneAndMapOps(firstprivateOps, firstprivateOperands);
196 cloneAndMapOps(privateOps, privateOperands);
197 cloneAndMapOps(reductionOps, reductionOperands);
198
199 // Create new compute op without if condition for device execution by
200 // cloning
201 OpTy newComputeOp = cast<OpTy>(
202 rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
203 newComputeOp.getIfCondMutable().clear();
204 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
205 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
206 newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
207 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
208
209 // Clone data exit operations
210 rewriter.setInsertionPointAfter(newComputeOp);
211 for (Operation *dataOp : dataExitOps)
212 rewriter.clone(*dataOp, deviceMapping);
213
214 rewriter.setInsertionPointToEnd(&thenBlock);
215 if (!thenBlock.getTerminator())
216 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
217
218 // Host execution path (false branch)
219 Region &hostRegion = computeConstructOp.getRegion();
220 if (hostRegion.hasOneBlock()) {
221 // Don't need to clone original ops, just take them and legalize for host.
222 ifOp.getElseRegion().takeBody(hostRegion);
223
224 // Swap acc yield for scf yield.
225 Block &elseBlock = ifOp.getElseRegion().front();
226 elseBlock.getTerminator()->erase();
227 rewriter.setInsertionPointToEnd(&elseBlock);
228 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
229
230 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
231 } else {
232 // scf.if regions must stay single-block. Wrap the original multi-block ACC
233 // body in scf.execute_region so it can be hosted in the else branch.
234 Block &elseBlock = ifOp.getElseRegion().front();
235 rewriter.setInsertionPoint(elseBlock.getTerminator());
236 IRMapping hostMapping;
237 auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
238 hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
239 convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
240 }
241
242 // The original op is now empty and can be erased
243 eraseOps.push_back(computeConstructOp);
244
245 // TODO: Can probably 'move' the data ops instead of cloning them
246 // which would eliminate need to explicitly erase
247 for (Operation *dataOp : dataExitOps)
248 eraseOps.push_back(dataOp);
249
250 // The new host code may contain uses of the acc variables. Replace them by
251 // the host values.
252 auto replaceAndEraseOps = [&](SmallVector<Operation *> &ops) {
253 for (Operation *op : ops) {
255 eraseOps.push_back(op);
256 }
257 };
258 replaceAndEraseOps(dataEntryOps);
259 replaceAndEraseOps(firstprivateOps);
260 replaceAndEraseOps(privateOps);
261 replaceAndEraseOps(reductionOps);
262}
263
264void ACCIfClauseLowering::runOnOperation() {
265 func::FuncOp funcOp = getOperation();
266 accSupport = &getAnalysis<OpenACCSupport>();
267
269 funcOp.walk([&](Operation *op) {
270 if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
271 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
272 else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
273 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
274 else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
275 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
276 });
277
278 for (Operation *op : eraseOps)
279 op->erase();
280}
281
282} // namespace
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5098
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5067
scf::ExecuteRegionOp wrapMultiBlockRegionWithSCFExecuteRegion(Region &region, IRMapping &mapping, Location loc, RewriterBase &rewriter)
Wrap a multi-block region in an scf.execute_region.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
@ ExistingOps
Only pre-existing ops are processed.