67#include "llvm/ADT/STLExtras.h"
68#include "llvm/Support/Debug.h"
72#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
73#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
77#define DEBUG_TYPE "acc-if-clause-lowering"
84class ACCIfClauseLowering
85 :
public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
86 using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
93 template <
typename OpTy>
94 void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
98 void runOnOperation()
override;
101void ACCIfClauseLowering::convertHostRegion(
Operation *computeOp,
106 if (isa<acc::OpenACCDialect>(op->getDialect()))
107 hostOps.push_back(op);
114 config.setUseTopDownTraversal(
true);
117 accSupport->
emitNYI(computeOp->
getLoc(),
"failed to convert host region");
122template <
typename OpTy>
123void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
125 Value ifCond = computeConstructOp.getIfCond();
131 LLVM_DEBUG(llvm::dbgs() <<
"Converting " << computeConstructOp->getName()
132 <<
" with if condition: " << computeConstructOp
141 for (
Value operand : computeConstructOp.getDataClauseOperands()) {
142 if (
Operation *defOp = operand.getDefiningOp())
143 if (isa<ACC_DATA_ENTRY_OPS>(defOp))
144 dataEntryOps.push_back(defOp);
149 for (
Operation *dataEntryOp : llvm::reverse(dataEntryOps))
151 if (isa<ACC_DATA_EXIT_OPS>(user))
152 dataExitOps.push_back(user);
155 auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
162 Block &thenBlock = ifOp.getThenRegion().
front();
168 LLVM_DEBUG(llvm::dbgs() <<
"Cloning " << dataEntryOps.size()
169 <<
" data entry operations for device path\n");
173 deviceDataOperands.push_back(clonedDataOp->
getResult(0));
174 deviceMapping.
map(dataOp->getResult(0), clonedDataOp->
getResult(0));
179 OpTy newComputeOp = cast<OpTy>(
180 rewriter.
clone(*computeConstructOp.getOperation(), deviceMapping));
181 newComputeOp.getIfCondMutable().clear();
182 newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
187 rewriter.
clone(*dataOp, deviceMapping);
191 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
194 if (!computeConstructOp.getRegion().hasOneBlock()) {
195 accSupport->
emitNYI(computeConstructOp.getLoc(),
196 "region with multiple blocks");
201 ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
204 Block &elseBlock = ifOp.getElseRegion().
front();
207 scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
209 convertHostRegion(computeConstructOp, ifOp.getElseRegion());
212 eraseOps.push_back(computeConstructOp);
217 eraseOps.push_back(dataOp);
223 eraseOps.push_back(dataOp);
227void ACCIfClauseLowering::runOnOperation() {
228 func::FuncOp funcOp = getOperation();
229 accSupport = &getAnalysis<OpenACCSupport>();
233 if (
auto parallelOp = dyn_cast<acc::ParallelOp>(op))
234 lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
235 else if (
auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
236 lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
237 else if (
auto serialOp = dyn_cast<acc::SerialOp>(op))
238 lowerIfClauseForComputeConstruct(serialOp, eraseOps);
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns, OpenACCSupport &accSupport, bool enableLoopConversion=true)
Populates all patterns for host fallback path (when if clause evaluates to false).
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.