MLIR 22.0.0git
ACCSpecializeForDevice.cpp
Go to the documentation of this file.
//===- ACCSpecializeForDevice.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass strips OpenACC constructs that are invalid or unnecessary inside
// device code (specialized acc routines or compute construct regions).
//
// Overview:
// ---------
// In a specialized acc routine or compute construct, many OpenACC operations
// do not make sense because they are host-side constructs. This pass removes
// or transforms these operations appropriately:
//
// - Data operations that manage device memory from the host's perspective
// - Compute constructs that launch kernels (we are already on the device)
// - Runtime operations like init/shutdown/set/wait
//
// Transformations:
// ----------------
// The pass applies the following transformations:
//
// 1. Data Entry Ops (replaced with their var operand):
//    acc.attach, acc.copyin, acc.create, acc.declare_device_resident,
//    acc.declare_link, acc.deviceptr, acc.get_deviceptr, acc.nocreate,
//    acc.present, acc.update_device, acc.use_device
//
// 2. Data Exit Ops (erased):
//    acc.copyout, acc.delete, acc.detach, acc.update_host
//
// 3. Structured Data/Compute Constructs (region inlined):
//    acc.data, acc.host_data, acc.kernel_environment, acc.parallel,
//    acc.serial, acc.kernels
//
// 4. Unstructured Data Ops (erased):
//    acc.enter_data, acc.exit_data, acc.update, acc.declare_enter,
//    acc.declare_exit
//
// 5. Runtime Ops (erased):
//    acc.init, acc.shutdown, acc.set, acc.wait
//
// Scope of Application:
// ---------------------
// - For functions with the `acc.specialized_routine` attribute: patterns are
//   applied to the entire function body.
// - For non-specialized functions: patterns are applied only to ACC
//   operations INSIDE compute constructs (parallel, serial, kernels),
//   not to the compute constructs themselves or their data operands.
//
// Note: acc.cache, acc.private, acc.reduction, and acc.firstprivate are NOT
// transformed by this pass as they are valid in device code.
//
//===----------------------------------------------------------------------===//
57
59
65
66namespace mlir {
67namespace acc {
68#define GEN_PASS_DEF_ACCSPECIALIZEFORDEVICE
69#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
70} // namespace acc
71} // namespace mlir
72
73using namespace mlir;
74using namespace mlir::acc;
75
76namespace {
77
78class ACCSpecializeForDevice
79 : public acc::impl::ACCSpecializeForDeviceBase<ACCSpecializeForDevice> {
80public:
81 using ACCSpecializeForDeviceBase<
82 ACCSpecializeForDevice>::ACCSpecializeForDeviceBase;
83
84 void runOnOperation() override {
85 func::FuncOp func = getOperation();
86
87 RewritePatternSet patterns(&getContext());
89 GreedyRewriteConfig config;
90 config.setUseTopDownTraversal(true);
91
93 // For specialized acc routines, apply patterns to the entire function
94 (void)applyPatternsGreedily(func, std::move(patterns), config);
95 } else {
96 // For non-specialized functions, apply patterns only to ACC operations
97 // inside compute constructs (not to the compute constructs themselves).
98 SmallVector<Operation *> opsToTransform;
99 func.walk([&](Operation *op) {
100 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) {
101 // Walk inside the compute construct and collect ACC ops
102 op->walk([&](Operation *innerOp) {
103 // Skip the compute construct itself
104 if (innerOp == op)
105 return;
106 if (isa<acc::OpenACCDialect>(innerOp->getDialect()))
107 opsToTransform.push_back(innerOp);
108 });
109 }
110 });
111 if (!opsToTransform.empty())
112 (void)applyOpPatternsGreedily(opsToTransform, std::move(patterns),
113 config);
114 }
115 }
116};
117
118} // namespace
119
120//===----------------------------------------------------------------------===//
121// Pattern population functions
122//===----------------------------------------------------------------------===//
123
126 MLIRContext *context = patterns.getContext();
127
128 // Declare patterns - erase declare_enter and its associated declare_exit
129 patterns.insert<ACCDeclareEnterOpConversion>(context);
130
131 // Data entry ops - replaced with their var operand
132 // Note: acc.cache, acc.private, acc.reduction, acc.firstprivate are NOT
133 // included here - they are valid in device code
145
146 // Data exit ops - simply erased (no results)
151
152 // Structured data constructs - unwrap their regions
156
157 // Compute constructs - unwrap their regions
161
162 // Unstructured data operations - erase them
166
167 // Runtime operations - erase them
168 patterns.insert<
171 context);
172}
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
Pattern to erase acc.declare_enter and its associated acc.declare_exit.
Pattern to simply erase an ACC op (for ops with no results).
Pattern to replace an ACC op with its var operand.
Pattern to unwrap a region from an ACC op and erase the wrapper.
void populateACCSpecializeForDevicePatterns(RewritePatternSet &patterns)
Populates all patterns for device specialization.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
Definition OpenACC.h:195
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns