MLIR 23.0.0git
ACCIfClauseLowering.cpp
Go to the documentation of this file.
//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers OpenACC compute constructs (parallel, kernels, serial) that
// carry an `if` clause by using region specialization. It creates two
// execution paths: device execution when the condition is true, and host
// execution when it is false.
//
// Overview:
// ---------
// When an ACC compute construct has an `if` clause, the construct must only
// execute on the device when the condition is true. If the condition is false,
// the code executes on the host instead. This pass transforms:
//
//   acc.parallel if(%cond) { ... }
//
// into:
//
//   scf.if %cond {
//     // Device path: cloned data entry ops, the compute construct without
//     // the `if` clause, cloned data exit ops
//     acc.parallel { ... }
//   } else {
//     // Host path: original region body with ACC ops converted to host code
//   }
//
// Transformations:
// ----------------
// For each compute construct with an `if` clause:
//
// 1. Device path (true branch):
//    - Clone the data entry operations (acc.copyin, acc.create, etc.), as
//      well as the firstprivate and reduction operations
//    - Clone the compute construct without the `if` clause
//    - Clone the data exit operations (acc.copyout, acc.delete, etc.)
//
// 2. Host path (false branch):
//    - Move the original region body into the else branch
//    - Apply host fallback patterns to convert ACC ops to host equivalents
//
// 3. Cleanup:
//    - Erase the original compute construct and its data operations
//    - Replace uses of ACC variables with host variables in the else branch
//
// Requirements:
// -------------
// To use this pass in a pipeline, the following requirements apply:
//
// 1. Analysis registration (optional): if custom behavior is needed for
//    emitting not-yet-implemented messages for unsupported cases, the pipeline
//    should pre-register the `acc::OpenACCSupport` analysis.
//
//===----------------------------------------------------------------------===//
56
64#include "mlir/IR/IRMapping.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
69
70namespace mlir {
71namespace acc {
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
74} // namespace acc
75} // namespace mlir
76
77#define DEBUG_TYPE "acc-if-clause-lowering"
78
79using namespace mlir;
80using namespace mlir::acc;
81
82namespace {
83
// Pass that lowers the `if` clause of acc.parallel / acc.kernels / acc.serial
// compute constructs into an scf.if with a device path and a host path.
84class ACCIfClauseLowering
85 : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
// Inherit the constructors of the generated pass base class.
86 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
87
88private:
// OpenACC support analysis, fetched in runOnOperation(); used (at least) to
// report not-yet-implemented diagnostics for unsupported cases.
89 OpenACCSupport *accSupport = nullptr;
// NOTE(review): original source line 90 is absent from this extract; its
// content is unknown.
// Converts the ACC dialect operations in `region` (the host fallback body) to
// host equivalents; `computeOp` supplies the location for diagnostics.
91 void convertHostRegion(Operation *computeOp, Region &region);
92
// Lowers the `if` clause of a single compute construct. Operations that become
// dead are appended to `eraseOps` so the caller can erase them later.
93 template <typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
95 SmallVector<Operation *> &eraseOps);
96
97public:
// Pass entry point: lowers every parallel/kernels/serial op in the function.
98 void runOnOperation() override;
99};
100
// Rewrites the ACC dialect operations found in `region` (the host fallback
// body) into host equivalents by greedily applying rewrite patterns to them.
// If the greedy rewrite reports failure, a not-yet-implemented diagnostic is
// emitted at `computeOp`'s location; the region is left as the driver left it.
101void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
102 Region &region) {
103 // Only collect ACC dialect operations - other ops don't need conversion
// Pre-order walk: an ACC op is recorded before any ACC ops nested inside it.
105 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
108 });
109
// NOTE(review): original source lines 104, 110-111, 113 and 115 are missing
// from this extract, including the declarations of `hostOps`, `patterns` and
// `config` used in this function. The cross-reference index suggests that
// `patterns` is filled by populateACCHostFallbackPatterns and that `config`
// restricts rewriting to pre-existing ops (ExistingOps) - TODO confirm
// against the real source.
112
// Process the worklist top-down.
114 config.setUseTopDownTraversal(true);
// The driver is seeded only with the collected ACC ops.
116 if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
117 accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
118}
119
120// Template function to handle if condition conversion for ACC compute
121// constructs
122template <typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
124 OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
125 Value ifCond = computeConstructOp.getIfCond();
126 if (!ifCond)
127 return;
128
129 IRRewriter rewriter(computeConstructOp);
130
131 LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
132 << " with if condition: " << computeConstructOp
133 << "\n");
134
135 // Collect data clause operations that need to be recreated in the if
136 // condition
137 SmallVector<Operation *> dataEntryOps;
138 SmallVector<Operation *> dataExitOps;
139 SmallVector<Operation *> firstprivateOps;
140 SmallVector<Operation *> reductionOps;
141
142 // Collect data entry operations
143 for (Value operand : computeConstructOp.getDataClauseOperands()) {
144 if (Operation *defOp = operand.getDefiningOp())
145 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
146 dataEntryOps.push_back(defOp);
147 }
148 // Collect firstprivate operations
149 for (Value operand : computeConstructOp.getFirstprivateOperands()) {
150 if (Operation *defOp = operand.getDefiningOp())
151 firstprivateOps.push_back(defOp);
152 }
153
154 // Collect reduction operations
155 for (Value operand : computeConstructOp.getReductionOperands()) {
156 if (Operation *defOp = operand.getDefiningOp())
157 reductionOps.push_back(defOp);
158 }
159
160 // Find corresponding exit operations for each entry operation.
161 // Iterate backwards through entry ops since exit ops appear in reverse order.
162 for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
163 for (Operation *user : dataEntryOp->getUsers())
164 if (isa<ACC_DATA_EXIT_OPS>(user))
165 dataExitOps.push_back(user);
166
167 // Create scf.if with device and host execution paths
168 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
169 TypeRange{}, ifCond, /*withElseRegion=*/true);
170
171 LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
172 << " data entry operations for device path\n");
173
174 // Device execution path (true branch)
175 Block &thenBlock = ifOp.getThenRegion().front();
176 rewriter.setInsertionPointToStart(&thenBlock);
177
178 // Clone data entry operations
179 SmallVector<Value> deviceDataOperands;
180 SmallVector<Value> firstprivateOperands;
181 SmallVector<Value> reductionOperands;
182
183 // Map the data entry and firstprivate ops for the cloned region
184 IRMapping deviceMapping;
185 for (Operation *dataOp : dataEntryOps) {
186 Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
187 deviceDataOperands.push_back(clonedDataOp->getResult(0));
188 deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
189 }
190 for (Operation *firstprivateOp : firstprivateOps) {
191 Operation *clonedOp = rewriter.clone(*firstprivateOp, deviceMapping);
192 firstprivateOperands.push_back(clonedOp->getResult(0));
193 deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0));
194 }
195 for (Operation *reductionOp : reductionOps) {
196 Operation *clonedOp = rewriter.clone(*reductionOp, deviceMapping);
197 reductionOperands.push_back(clonedOp->getResult(0));
198 deviceMapping.map(reductionOp->getResult(0), clonedOp->getResult(0));
199 }
200
201 // Create new compute op without if condition for device execution by
202 // cloning
203 OpTy newComputeOp = cast<OpTy>(
204 rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
205 newComputeOp.getIfCondMutable().clear();
206 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
207 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
208 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
209
210 // Clone data exit operations
211 rewriter.setInsertionPointAfter(newComputeOp);
212 for (Operation *dataOp : dataExitOps)
213 rewriter.clone(*dataOp, deviceMapping);
214
215 rewriter.setInsertionPointToEnd(&thenBlock);
216 if (!thenBlock.getTerminator())
217 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
218
219 // Host execution path (false branch)
220 if (!computeConstructOp.getRegion().hasOneBlock()) {
221 accSupport->emitNYI(computeConstructOp.getLoc(),
222 "region with multiple blocks");
223 return;
224 }
225
226 // Don't need to clone original ops, just take them and legalize for host
227 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
228
229 // Swap acc yield for scf yield
230 Block &elseBlock = ifOp.getElseRegion().front();
231 elseBlock.getTerminator()->erase();
232 rewriter.setInsertionPointToEnd(&elseBlock);
233 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
234
235 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
236
237 // The original op is now empty and can be erased
238 eraseOps.push_back(computeConstructOp);
239
240 // TODO: Can probably 'move' the data ops instead of cloning them
241 // which would eliminate need to explicitly erase
242 for (Operation *dataOp : dataExitOps)
243 eraseOps.push_back(dataOp);
244
245 // The new host code may contain uses of the acc variables. Replace them by
246 // the host values.
247 for (Operation *dataOp : dataEntryOps) {
248 getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
249 eraseOps.push_back(dataOp);
250 }
251 for (Operation *firstprivateOp : firstprivateOps) {
252 getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp));
253 eraseOps.push_back(firstprivateOp);
254 }
255 for (Operation *reductionOp : reductionOps) {
256 getAccVar(reductionOp).replaceAllUsesWith(getVar(reductionOp));
257 eraseOps.push_back(reductionOp);
258 }
259}
260
261void ACCIfClauseLowering::runOnOperation() {
262 func::FuncOp funcOp = getOperation();
263 accSupport = &getAnalysis<OpenACCSupport>();
264
265 SmallVector<Operation *> eraseOps;
266 funcOp.walk([&](Operation *op) {
267 if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
268 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
269 else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
270 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
271 else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
272 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
273 });
274
275 for (Operation *op : eraseOps)
276 op->erase();
277}
278
279} // namespace
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class allows control over how the GreedyPatternRewriteDriver works.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4923
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4892
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.