67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
77#define DEBUG_TYPE "acc-if-clause-lowering"
84class ACCIfClauseLowering
85 :
public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
86 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
93 template <
typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
98 void runOnOperation()
override;
101void ACCIfClauseLowering::convertHostRegion(
Operation *computeOp,
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
114 config.setUseTopDownTraversal(
true);
117 accSupport->
emitNYI(computeOp->
getLoc(),
"failed to convert host region");
122template <
typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
125 Value ifCond = computeConstructOp.getIfCond();
131 LLVM_DEBUG(llvm::dbgs() <<
"Converting " << computeConstructOp->getName()
132 <<
" with if condition: " << computeConstructOp
144 for (
Value operand : computeConstructOp.getDataClauseOperands())
145 if (
Operation *defOp = operand.getDefiningOp())
146 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
147 dataEntryOps.push_back(defOp);
151 for (
Operation *dataEntryOp : llvm::reverse(dataEntryOps))
153 if (isa<ACC_DATA_EXIT_OPS>(user))
154 dataExitOps.push_back(user);
158 for (
Value operand : operands)
159 if (
Operation *defOp = operand.getDefiningOp())
160 ops.push_back(defOp);
162 collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
163 collectOps(privateOps, computeConstructOp.getPrivateOperands());
164 collectOps(reductionOps, computeConstructOp.getReductionOperands());
167 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
170 LLVM_DEBUG(llvm::dbgs() <<
"Cloning " << dataEntryOps.size()
171 <<
" data entry operations for device path\n");
174 Block &thenBlock = ifOp.getThenRegion().
front();
189 operands.push_back(clonedOp->
getResult(0));
190 deviceMapping.
map(op->getResult(0), clonedOp->
getResult(0));
193 cloneAndMapOps(dataEntryOps, deviceDataOperands);
194 cloneAndMapOps(firstprivateOps, firstprivateOperands);
195 cloneAndMapOps(privateOps, privateOperands);
196 cloneAndMapOps(reductionOps, reductionOperands);
200 OpTy newComputeOp = cast<OpTy>(
201 rewriter.
clone(*computeConstructOp.getOperation(), deviceMapping));
202 newComputeOp.getIfCondMutable().clear();
203 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
204 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
205 newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
206 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
211 rewriter.
clone(*dataOp, deviceMapping);
215 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
218 if (!computeConstructOp.getRegion().hasOneBlock()) {
219 accSupport->
emitNYI(computeConstructOp.getLoc(),
220 "region with multiple blocks");
225 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
228 Block &elseBlock = ifOp.getElseRegion().
front();
231 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
233 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
236 eraseOps.push_back(computeConstructOp);
241 eraseOps.push_back(dataOp);
248 eraseOps.push_back(op);
251 replaceAndEraseOps(dataEntryOps);
252 replaceAndEraseOps(firstprivateOps);
253 replaceAndEraseOps(privateOps);
254 replaceAndEraseOps(reductionOps);
257void ACCIfClauseLowering::runOnOperation() {
258 func::FuncOp funcOp = getOperation();
259 accSupport = &getAnalysis<OpenACCSupport>();
263 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(op))
264 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
265 else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
266 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
267 else if (
auto serialOp = dyn_cast<acc::SerialOp>(op))
268 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.