68#include "llvm/ADT/STLExtras.h"
69#include "llvm/Support/Debug.h"
73#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
74#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
78#define DEBUG_TYPE "acc-if-clause-lowering"
85class ACCIfClauseLowering
86 :
public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
87 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
94 template <
typename OpTy>
95 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
99 void runOnOperation()
override;
102void ACCIfClauseLowering::convertHostRegion(
Operation *computeOp,
107 if (isa<acc::OpenACCDialect>(op->getDialect()))
108 hostOps.push_back(op);
118 accSupport->
emitNYI(computeOp->
getLoc(),
"failed to convert host region");
123template <
typename OpTy>
124void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
126 Value ifCond = computeConstructOp.getIfCond();
132 LLVM_DEBUG(llvm::dbgs() <<
"Converting " << computeConstructOp->getName()
133 <<
" with if condition: " << computeConstructOp
145 for (
Value operand : computeConstructOp.getDataClauseOperands())
146 if (
Operation *defOp = operand.getDefiningOp())
147 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
148 dataEntryOps.push_back(defOp);
152 for (
Operation *dataEntryOp : llvm::reverse(dataEntryOps))
154 if (isa<ACC_DATA_EXIT_OPS>(user))
155 dataExitOps.push_back(user);
159 for (
Value operand : operands)
160 if (
Operation *defOp = operand.getDefiningOp())
161 ops.push_back(defOp);
163 collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
164 collectOps(privateOps, computeConstructOp.getPrivateOperands());
165 collectOps(reductionOps, computeConstructOp.getReductionOperands());
168 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
171 LLVM_DEBUG(llvm::dbgs() <<
"Cloning " << dataEntryOps.size()
172 <<
" data entry operations for device path\n");
175 Block &thenBlock = ifOp.getThenRegion().
front();
190 operands.push_back(clonedOp->
getResult(0));
191 deviceMapping.
map(op->getResult(0), clonedOp->
getResult(0));
194 cloneAndMapOps(dataEntryOps, deviceDataOperands);
195 cloneAndMapOps(firstprivateOps, firstprivateOperands);
196 cloneAndMapOps(privateOps, privateOperands);
197 cloneAndMapOps(reductionOps, reductionOperands);
201 OpTy newComputeOp = cast<OpTy>(
202 rewriter.
clone(*computeConstructOp.getOperation(), deviceMapping));
203 newComputeOp.getIfCondMutable().clear();
204 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
205 newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
206 newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
207 newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
212 rewriter.
clone(*dataOp, deviceMapping);
216 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
219 Region &hostRegion = computeConstructOp.getRegion();
222 ifOp.getElseRegion().takeBody(hostRegion);
225 Block &elseBlock = ifOp.getElseRegion().
front();
228 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
230 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
234 Block &elseBlock = ifOp.getElseRegion().
front();
238 hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
239 convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
243 eraseOps.push_back(computeConstructOp);
248 eraseOps.push_back(dataOp);
255 eraseOps.push_back(op);
258 replaceAndEraseOps(dataEntryOps);
259 replaceAndEraseOps(firstprivateOps);
260 replaceAndEraseOps(privateOps);
261 replaceAndEraseOps(reductionOps);
264void ACCIfClauseLowering::runOnOperation() {
265 func::FuncOp funcOp = getOperation();
266 accSupport = &getAnalysis<OpenACCSupport>();
270 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(op))
271 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
272 else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
273 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
274 else if (
auto serialOp = dyn_cast<acc::SerialOp>(op))
275 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
scf::ExecuteRegionOp wrapMultiBlockRegionWithSCFExecuteRegion(Region &region, IRMapping &mapping, Location loc, RewriterBase &rewriter)
Wrap a multi-block region in an scf.execute_region.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist driver.
@ ExistingOps
Only pre-existing ops are processed.