67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
77#define DEBUG_TYPE "acc-if-clause-lowering"
84class ACCIfClauseLowering
93 template <
typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
98 void runOnOperation()
override;
101void ACCIfClauseLowering::convertHostRegion(
Operation *computeOp,
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
114 config.setUseTopDownTraversal(
true);
117 accSupport->
emitNYI(computeOp->
getLoc(),
"failed to convert host region");
122template <
typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
124 OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
125 Value ifCond = computeConstructOp.getIfCond();
129 IRRewriter rewriter(computeConstructOp);
131 LLVM_DEBUG(llvm::dbgs() <<
"Converting " << computeConstructOp->getName()
132 <<
" with if condition: " << computeConstructOp
137 SmallVector<Operation *> dataEntryOps;
138 SmallVector<Operation *> dataExitOps;
139 SmallVector<Operation *> firstprivateOps;
140 SmallVector<Operation *> reductionOps;
143 for (Value operand : computeConstructOp.getDataClauseOperands()) {
144 if (Operation *defOp = operand.getDefiningOp())
145 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
146 dataEntryOps.push_back(defOp);
149 for (Value operand : computeConstructOp.getFirstprivateOperands()) {
150 if (Operation *defOp = operand.getDefiningOp())
151 firstprivateOps.push_back(defOp);
155 for (Value operand : computeConstructOp.getReductionOperands()) {
156 if (Operation *defOp = operand.getDefiningOp())
157 reductionOps.push_back(defOp);
162 for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
163 for (Operation *user : dataEntryOp->getUsers())
164 if (isa<ACC_DATA_EXIT_OPS>(user))
165 dataExitOps.push_back(user);
168 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
171 LLVM_DEBUG(llvm::dbgs() <<
"Cloning " << dataEntryOps.size()
172 <<
" data entry operations for device path\n");
175 Block &thenBlock = ifOp.getThenRegion().front();
176 rewriter.setInsertionPointToStart(&thenBlock);
179 SmallVector<Value> deviceDataOperands;
180 SmallVector<Value> firstprivateOperands;
181 SmallVector<Value> reductionOperands;
184 IRMapping deviceMapping;
185 for (Operation *dataOp : dataEntryOps) {
186 Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
187 deviceDataOperands.push_back(clonedDataOp->
getResult(0));
188 deviceMapping.
map(dataOp->getResult(0), clonedDataOp->
getResult(0));
190 for (Operation *firstprivateOp : firstprivateOps) {
191 Operation *clonedOp = rewriter.clone(*firstprivateOp, deviceMapping);
192 firstprivateOperands.push_back(clonedOp->
getResult(0));
193 deviceMapping.
map(firstprivateOp->getResult(0), clonedOp->
getResult(0));
195 for (Operation *reductionOp : reductionOps) {
196 Operation *clonedOp = rewriter.
clone(*reductionOp, deviceMapping);
197 reductionOperands.push_back(clonedOp->
getResult(0));
198 deviceMapping.
map(reductionOp->getResult(0), clonedOp->
getResult(0));
203 OpTy newComputeOp = cast<OpTy>(
204 rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
205 newComputeOp.getIfCondMutable().clear();
206 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
207 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
208 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
211 rewriter.setInsertionPointAfter(newComputeOp);
212 for (Operation *dataOp : dataExitOps)
213 rewriter.clone(*dataOp, deviceMapping);
215 rewriter.setInsertionPointToEnd(&thenBlock);
217 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
220 if (!computeConstructOp.getRegion().hasOneBlock()) {
221 accSupport->
emitNYI(computeConstructOp.getLoc(),
222 "region with multiple blocks");
227 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
230 Block &elseBlock = ifOp.getElseRegion().front();
232 rewriter.setInsertionPointToEnd(&elseBlock);
233 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
235 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
238 eraseOps.push_back(computeConstructOp);
242 for (Operation *dataOp : dataExitOps)
243 eraseOps.push_back(dataOp);
247 for (Operation *dataOp : dataEntryOps) {
249 eraseOps.push_back(dataOp);
251 for (Operation *firstprivateOp : firstprivateOps) {
253 eraseOps.push_back(firstprivateOp);
255 for (Operation *reductionOp : reductionOps) {
257 eraseOps.push_back(reductionOp);
261void ACCIfClauseLowering::runOnOperation() {
262 func::FuncOp funcOp = getOperation();
263 accSupport = &getAnalysis<OpenACCSupport>();
265 SmallVector<Operation *> eraseOps;
266 funcOp.walk([&](Operation *op) {
267 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(op))
268 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
269 else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
270 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
271 else if (
auto serialOp = dyn_cast<acc::SerialOp>(op))
272 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
275 for (Operation *op : eraseOps)
Operation * getTerminator()
Get the terminator operation of this block.
This class allows control over how the GreedyPatternRewriteDriver works.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
ACCIfClauseLoweringBase()
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.