MLIR 23.0.0git
ACCIfClauseLowering.cpp
Go to the documentation of this file.
1//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
10// `if` clauses using region specialization. It creates two execution paths:
11// device execution when the condition is true, host execution when false.
12//
13// Overview:
14// ---------
15// When an ACC compute construct has an `if` clause, the construct should only
16// execute on the device when the condition is true. If the condition is false,
17// the code should execute on the host instead. This pass transforms:
18//
19// acc.parallel if(%cond) { ... }
20//
21// Into:
22//
23// scf.if %cond {
24// // Device path: clone data ops, compute construct without if, exit ops
25// acc.parallel { ... }
26// } else {
27// // Host path: original region body with ACC ops converted to host
28// }
29//
30// Transformations:
31// ----------------
32// For each compute construct with an `if` clause:
33//
34// 1. Device Path (true branch):
35// - Clone data entry operations (acc.copyin, acc.create, etc.)
36// - Clone the compute construct without the `if` clause
37// - Clone data exit operations (acc.copyout, acc.delete, etc.)
38//
39// 2. Host Path (false branch):
40// - Move the original region body to the else branch
41// - Apply host fallback patterns to convert ACC ops to host equivalents
42//
43// 3. Cleanup:
44// - Erase the original compute construct and data operations
45// - Replace uses of ACC variables with host variables in the else branch
46//
47// Requirements:
48// -------------
49// To use this pass in a pipeline, note the following:
50//
51// 1. Analysis Registration (Optional): If custom behavior is needed for
52// emitting not-yet-implemented messages for unsupported cases, the pipeline
53// should pre-register the `acc::OpenACCSupport` analysis.
54//
55//===----------------------------------------------------------------------===//
56
58
63#include "mlir/IR/Builders.h"
64#include "mlir/IR/IRMapping.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
69
70namespace mlir {
71namespace acc {
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
74} // namespace acc
75} // namespace mlir
76
77#define DEBUG_TYPE "acc-if-clause-lowering"
78
79using namespace mlir;
80using namespace mlir::acc;
81
82namespace {
83
84class ACCIfClauseLowering
85 : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
86 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
87
88private:
89 OpenACCSupport *accSupport = nullptr;
90
91 void convertHostRegion(Operation *computeOp, Region &region);
92
93 template <typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
95 SmallVector<Operation *> &eraseOps);
96
97public:
98 void runOnOperation() override;
99};
100
// Converts the ACC dialect operations nested in `region` to host equivalents.
// The caller passes the else (host) region of the newly created scf.if.
// All ACC ops are first collected in pre-order, and then the greedy pattern
// driver (applyOpPatternsGreedily) is run over that op list. If the driver
// reports failure, a not-yet-implemented diagnostic is emitted at the location
// of `computeOp`. No failure is propagated to the caller.
101void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
102 Region &region) {
103 // Only collect ACC dialect operations - other ops don't need conversion
105 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
// NOTE(review): Operation::getDialect() returns nullptr for ops whose dialect
// is not loaded (e.g. unregistered ops), and isa<> on a null pointer asserts;
// isa_and_nonnull<> may be intended here - confirm.
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
108 });
109
112
// Process the worklist top-down. Because the ops were collected in pre-order,
// enclosing ACC ops are presumably rewritten before the ops nested in them.
114 config.setUseTopDownTraversal(true);
// A non-converging/failed rewrite is only reported here, not treated as fatal.
116 if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
117 accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
118}
119
120// Template function to handle if condition conversion for ACC compute
121// constructs
122template <typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
124 OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
125 Value ifCond = computeConstructOp.getIfCond();
126 if (!ifCond)
127 return;
128
129 IRRewriter rewriter(computeConstructOp);
130
131 LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
132 << " with if condition: " << computeConstructOp
133 << "\n");
134
135 // Collect data clause operations that need to be recreated in the if
136 // condition
137 SmallVector<Operation *> dataEntryOps;
138 SmallVector<Operation *> dataExitOps;
139 SmallVector<Operation *> firstprivateOps;
140 SmallVector<Operation *> privateOps;
141 SmallVector<Operation *> reductionOps;
142
143 // Collect data entry operations
144 for (Value operand : computeConstructOp.getDataClauseOperands())
145 if (Operation *defOp = operand.getDefiningOp())
146 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
147 dataEntryOps.push_back(defOp);
148
149 // Find corresponding exit operations for each entry operation.
150 // Iterate backwards through entry ops since exit ops appear in reverse order.
151 for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
152 for (Operation *user : dataEntryOp->getUsers())
153 if (isa<ACC_DATA_EXIT_OPS>(user))
154 dataExitOps.push_back(user);
155
156 // Collect firstprivate, private, and reduction operations
157 auto collectOps = [&](SmallVector<Operation *> &ops, OperandRange operands) {
158 for (Value operand : operands)
159 if (Operation *defOp = operand.getDefiningOp())
160 ops.push_back(defOp);
161 };
162 collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
163 collectOps(privateOps, computeConstructOp.getPrivateOperands());
164 collectOps(reductionOps, computeConstructOp.getReductionOperands());
165
166 // Create scf.if with device and host execution paths
167 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
168 TypeRange{}, ifCond, /*withElseRegion=*/true);
169
170 LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
171 << " data entry operations for device path\n");
172
173 // Device execution path (true branch)
174 Block &thenBlock = ifOp.getThenRegion().front();
175 rewriter.setInsertionPointToStart(&thenBlock);
176
177 // Clone data entry operations
178 SmallVector<Value> deviceDataOperands;
179 SmallVector<Value> firstprivateOperands;
180 SmallVector<Value> privateOperands;
181 SmallVector<Value> reductionOperands;
182
183 // Map the data entry and firstprivate ops for the cloned region
184 IRMapping deviceMapping;
185 auto cloneAndMapOps = [&](SmallVector<Operation *> &ops,
186 SmallVector<Value> &operands) {
187 for (Operation *op : ops) {
188 Operation *clonedOp = rewriter.clone(*op, deviceMapping);
189 operands.push_back(clonedOp->getResult(0));
190 deviceMapping.map(op->getResult(0), clonedOp->getResult(0));
191 }
192 };
193 cloneAndMapOps(dataEntryOps, deviceDataOperands);
194 cloneAndMapOps(firstprivateOps, firstprivateOperands);
195 cloneAndMapOps(privateOps, privateOperands);
196 cloneAndMapOps(reductionOps, reductionOperands);
197
198 // Create new compute op without if condition for device execution by
199 // cloning
200 OpTy newComputeOp = cast<OpTy>(
201 rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
202 newComputeOp.getIfCondMutable().clear();
203 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
204 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
205 newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
206 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
207
208 // Clone data exit operations
209 rewriter.setInsertionPointAfter(newComputeOp);
210 for (Operation *dataOp : dataExitOps)
211 rewriter.clone(*dataOp, deviceMapping);
212
213 rewriter.setInsertionPointToEnd(&thenBlock);
214 if (!thenBlock.getTerminator())
215 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
216
217 // Host execution path (false branch)
218 if (!computeConstructOp.getRegion().hasOneBlock()) {
219 accSupport->emitNYI(computeConstructOp.getLoc(),
220 "region with multiple blocks");
221 return;
222 }
223
224 // Don't need to clone original ops, just take them and legalize for host
225 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
226
227 // Swap acc yield for scf yield
228 Block &elseBlock = ifOp.getElseRegion().front();
229 elseBlock.getTerminator()->erase();
230 rewriter.setInsertionPointToEnd(&elseBlock);
231 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
232
233 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
234
235 // The original op is now empty and can be erased
236 eraseOps.push_back(computeConstructOp);
237
238 // TODO: Can probably 'move' the data ops instead of cloning them
239 // which would eliminate need to explicitly erase
240 for (Operation *dataOp : dataExitOps)
241 eraseOps.push_back(dataOp);
242
243 // The new host code may contain uses of the acc variables. Replace them by
244 // the host values.
245 auto replaceAndEraseOps = [&](SmallVector<Operation *> &ops) {
246 for (Operation *op : ops) {
248 eraseOps.push_back(op);
249 }
250 };
251 replaceAndEraseOps(dataEntryOps);
252 replaceAndEraseOps(firstprivateOps);
253 replaceAndEraseOps(privateOps);
254 replaceAndEraseOps(reductionOps);
255}
256
257void ACCIfClauseLowering::runOnOperation() {
258 func::FuncOp funcOp = getOperation();
259 accSupport = &getAnalysis<OpenACCSupport>();
260
262 funcOp.walk([&](Operation *op) {
263 if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
264 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
265 else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
266 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
267 else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
268 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
269 });
270
271 for (Operation *op : eraseOps)
272 op->erase();
273}
274
275} // namespace
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5191
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5160
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.