MLIR 22.0.0git
ACCIfClauseLowering.cpp
Go to the documentation of this file.
1//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
10// `if` clauses using region specialization. It creates two execution paths:
11// device execution when the condition is true, host execution when false.
12//
13// Overview:
14// ---------
15// When an ACC compute construct has an `if` clause, the construct should only
16// execute on the device when the condition is true. If the condition is false,
17// the code should execute on the host instead. This pass transforms:
18//
19// acc.parallel if(%cond) { ... }
20//
21// Into:
22//
23// scf.if %cond {
24// // Device path: clone data ops, compute construct without if, exit ops
25// acc.parallel { ... }
26// } else {
27// // Host path: original region body with ACC ops converted to host
28// }
29//
30// Transformations:
31// ----------------
32// For each compute construct with an `if` clause:
33//
34// 1. Device Path (true branch):
35// - Clone data entry operations (acc.copyin, acc.create, etc.)
36// - Clone the compute construct without the `if` clause
37// - Clone data exit operations (acc.copyout, acc.delete, etc.)
38//
39// 2. Host Path (false branch):
40// - Move the original region body to the else branch
41// - Apply host fallback patterns to convert ACC ops to host equivalents
42//
43// 3. Cleanup:
44// - Erase the original compute construct and data operations
45// - Replace uses of ACC variables with host variables in the else branch
46//
47// Requirements:
48// -------------
49// To use this pass in a pipeline, the following requirements exist:
50//
51// 1. Analysis Registration (Optional): If custom behavior is needed for
52// emitting not-yet-implemented messages for unsupported cases, the pipeline
53// should pre-register the `acc::OpenACCSupport` analysis.
54//
55//===----------------------------------------------------------------------===//
56
58
63#include "mlir/IR/Builders.h"
64#include "mlir/IR/IRMapping.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
69
70namespace mlir {
71namespace acc {
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
74} // namespace acc
75} // namespace mlir
76
77#define DEBUG_TYPE "acc-if-clause-lowering"
78
79using namespace mlir;
80using namespace mlir::acc;
81
82namespace {
83
84class ACCIfClauseLowering
85 : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
86 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
87
88private:
89 OpenACCSupport *accSupport = nullptr;
90
91 void convertHostRegion(Operation *computeOp, Region &region);
92
93 template <typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
95 SmallVector<Operation *> &eraseOps);
96
97public:
98 void runOnOperation() override;
99};
100
101void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
102 Region &region) {
103 // Only collect ACC dialect operations - other ops don't need conversion
105 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
108 });
109
112
114 config.setUseTopDownTraversal(true);
116 if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
117 accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
118}
119
120// Template function to handle if condition conversion for ACC compute
121// constructs
122template <typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
124 OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
125 Value ifCond = computeConstructOp.getIfCond();
126 if (!ifCond)
127 return;
128
129 IRRewriter rewriter(computeConstructOp);
130
131 LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
132 << " with if condition: " << computeConstructOp
133 << "\n");
134
135 // Collect data clause operations that need to be recreated in the if
136 // condition
137 SmallVector<Operation *> dataEntryOps;
138 SmallVector<Operation *> dataExitOps;
139
140 // Collect data entry operations
141 for (Value operand : computeConstructOp.getDataClauseOperands()) {
142 if (Operation *defOp = operand.getDefiningOp())
143 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
144 dataEntryOps.push_back(defOp);
145 }
146
147 // Find corresponding exit operations for each entry operation.
148 // Iterate backwards through entry ops since exit ops appear in reverse order.
149 for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
150 for (Operation *user : dataEntryOp->getUsers())
151 if (isa<ACC_DATA_EXIT_OPS>(user))
152 dataExitOps.push_back(user);
153
154 // Create scf.if with device and host execution paths
155 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
156 TypeRange{}, ifCond, /*withElseRegion=*/true);
157
158 // Declare deviceMapping at function scope for later use
159 IRMapping deviceMapping;
160
161 // Device execution path (true branch)
162 Block &thenBlock = ifOp.getThenRegion().front();
163 rewriter.setInsertionPointToStart(&thenBlock);
164
165 // Clone data entry operations
166 SmallVector<Value> deviceDataOperands;
167
168 LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
169 << " data entry operations for device path\n");
170
171 for (Operation *dataOp : dataEntryOps) {
172 Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
173 deviceDataOperands.push_back(clonedDataOp->getResult(0));
174 deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
175 }
176
177 // Create new compute op without if condition for device execution by
178 // cloning
179 OpTy newComputeOp = cast<OpTy>(
180 rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
181 newComputeOp.getIfCondMutable().clear();
182 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
183
184 // Clone data exit operations
185 rewriter.setInsertionPointAfter(newComputeOp);
186 for (Operation *dataOp : dataExitOps)
187 rewriter.clone(*dataOp, deviceMapping);
188
189 rewriter.setInsertionPointToEnd(&thenBlock);
190 if (!thenBlock.getTerminator())
191 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
192
193 // Host execution path (false branch)
194 if (!computeConstructOp.getRegion().hasOneBlock()) {
195 accSupport->emitNYI(computeConstructOp.getLoc(),
196 "region with multiple blocks");
197 return;
198 }
199
200 // Don't need to clone original ops, just take them and legalize for host
201 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
202
203 // Swap acc yield for scf yield
204 Block &elseBlock = ifOp.getElseRegion().front();
205 elseBlock.getTerminator()->erase();
206 rewriter.setInsertionPointToEnd(&elseBlock);
207 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
208
209 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
210
211 // The original op is now empty and can be erased
212 eraseOps.push_back(computeConstructOp);
213
214 // TODO: Can probably 'move' the data ops instead of cloning them
215 // which would eliminate need to explicitly erase
216 for (Operation *dataOp : dataExitOps)
217 eraseOps.push_back(dataOp);
218
219 for (Operation *dataOp : dataEntryOps) {
220 // The new host code may contain uses of the acc variables. Replace them by
221 // the host values.
222 getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
223 eraseOps.push_back(dataOp);
224 }
225}
226
227void ACCIfClauseLowering::runOnOperation() {
228 func::FuncOp funcOp = getOperation();
229 accSupport = &getAnalysis<OpenACCSupport>();
230
232 funcOp.walk([&](Operation *op) {
233 if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
234 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
235 else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
236 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
237 else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
238 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
239 });
240
241 for (Operation *op : eraseOps)
242 op->erase();
243}
244
245} // namespace
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4854
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4823
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.