MLIR 23.0.0git
OffloadTargetVerifier.cpp
Go to the documentation of this file.
1//===- OffloadTargetVerifier.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass verifies that values and symbols used within offload regions are
10// legal for the target execution model.
11//
12// Overview:
13// ---------
14// Offload regions execute on a target device (e.g., GPU) where not all values
15// and symbols from the host context are accessible. This pass checks that
16// live-in values (values defined outside but used inside the region) and
17// symbol references are valid for device execution.
18//
19// The pass operates on any operation implementing `OffloadRegionOpInterface`,
20// which includes OpenACC compute constructs (`acc.parallel`, `acc.kernels`,
21// `acc.serial`) as well as GPU operations like `gpu.launch`.
22//
23// Verification:
24// -------------
25// For each offload region, the pass checks:
26//
27// 1. Live-in Values: Values flowing into the region must be valid for device
28// use. This includes checking that data has been properly mapped via
29// OpenACC data clauses (copyin, copyout, present, etc.) or is a scalar
30// that can be passed by value.
31//
32// 2. Symbol References: Symbols referenced inside the region must be
33// accessible on the device. This includes checking for proper `declare`
34// attributes on globals or device-resident data attributes.
35//
36// Requirements:
37// -------------
38// 1. Target Region Identification: Operations representing offload regions
39// must implement `acc::OffloadRegionOpInterface`.
40//
41// 2. OpenACCSupport Analysis: The pass relies on the `OpenACCSupport`
42// analysis to determine value and symbol validity. This analysis provides
43// dialect-specific hooks for checking legality through `isValidValueUse`
44// and `isValidSymbolUse` methods. Custom dialect support can be registered
45// by providing a derived `OpenACCSupport` analysis before running this
46// pass.
47//
48// 3. Device Type: The `device_type` option specifies the target device.
49// For `host` or `multicore` targets, verification of ACC compute
50// constructs is not yet implemented.
51//
52//===----------------------------------------------------------------------===//
53
59#include "mlir/IR/SymbolTable.h"
60#include "llvm/Support/Debug.h"
61
62namespace mlir {
63namespace acc {
64#define GEN_PASS_DEF_OFFLOADTARGETVERIFIER
65#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
66} // namespace acc
67} // namespace mlir
68
69#define DEBUG_TYPE "offload-target-verifier"
70
71using namespace mlir;
72
73namespace {
74
75class OffloadTargetVerifier
76 : public acc::impl::OffloadTargetVerifierBase<OffloadTargetVerifier> {
77public:
78 using OffloadTargetVerifierBase::OffloadTargetVerifierBase;
79
80 /// Returns true if the target device type corresponds to host execution.
81 bool isHostTarget() const {
82 return deviceType == acc::DeviceType::Host ||
83 deviceType == acc::DeviceType::Multicore;
84 }
85
86 /// Check live-in values for legality.
88 getIllegalLiveInValues(Region &region, Liveness &liveness,
89 acc::OpenACCSupport &accSupport) const {
90 auto isInvalid = [&](Value val) -> bool {
91 return !accSupport.isValidValueUse(val, region);
92 };
93
94 SmallVector<Value> illegalValues(llvm::make_filter_range(
95 liveness.getLiveIn(&region.front()), isInvalid));
96
97 return illegalValues;
98 }
99
100 /// Check symbol uses for legality.
102 getIllegalUsedSymbols(Region &region, acc::OpenACCSupport &accSupport) const {
103 auto symUses = SymbolTable::getSymbolUses(&region);
104
105 // When there are no symbols used in the region, there are no illegal ones.
106 if (!symUses.has_value())
107 return {};
108
109 auto isInvalidSymbol = [&](const SymbolTable::SymbolUse &symUse) -> bool {
110 Operation *definingOp = nullptr;
111 return !accSupport.isValidSymbolUse(symUse.getUser(),
112 symUse.getSymbolRef(), &definingOp);
113 };
114
115 auto invalidSyms =
116 llvm::make_filter_range(symUses.value(), isInvalidSymbol);
117 SmallVector<SymbolTable::SymbolUse> invalidSymsList(invalidSyms);
118 return invalidSymsList;
119 }
120
121 /// Retrieve variable names for the given values.
123 getVariableNames(ArrayRef<Value> values, acc::OpenACCSupport &accSupport) {
125 names.reserve(values.size());
126 for (Value value : values)
127 names.push_back(accSupport.getVariableName(value));
128 return names;
129 }
130
131 /// Check if the region has illegal live-in values.
132 bool hasIllegalLiveInValues(Operation *regionOp,
133 acc::OpenACCSupport &accSupport) const {
134 if (regionOp->getNumRegions() == 0)
135 return false;
136
137 Liveness liveness(regionOp);
138 SmallVector<Value> invalidValues =
139 getIllegalLiveInValues(regionOp->getRegion(0), liveness, accSupport);
140
141 bool hasIllegalValues = !invalidValues.empty();
142
143 if (hasIllegalValues) {
144 SmallVector<std::string> invalidVarNames =
145 getVariableNames(invalidValues, accSupport);
146
147 if (softCheck) {
148 // Emit warnings for each illegal value.
149 auto diag = regionOp->emitWarning("offload target verifier: ")
150 << invalidValues.size() << " illegal live-in value(s)";
151 for (auto [invalidValue, name] :
152 llvm::zip(invalidValues, invalidVarNames)) {
153 if (name.empty()) {
154 diag.attachNote(invalidValue.getLoc()) << "value: " << invalidValue;
155 } else {
156 diag.attachNote(invalidValue.getLoc())
157 << "value: " << invalidValue << ", name: " << name;
158 }
159 }
160 } else {
161 std::string message = "offload target verifier failed due to " +
162 std::to_string(invalidValues.size()) +
163 " illegal live-in value(s)";
164 SmallVector<std::string> availableVarNames;
165 for (const std::string &name : invalidVarNames)
166 if (!name.empty())
167 availableVarNames.push_back(name);
168 if (!availableVarNames.empty())
169 message += " including: " + llvm::join(availableVarNames, ", ");
170 accSupport.emitNYI(regionOp->getLoc(), message);
171 }
172 }
173
174 return hasIllegalValues;
175 }
176
177 /// Check if the region has illegal symbol uses.
178 bool hasIllegalSymbolUses(Operation *regionOp,
179 acc::OpenACCSupport &accSupport) const {
180 if (regionOp->getNumRegions() == 0)
181 return false;
182
184 getIllegalUsedSymbols(regionOp->getRegion(0), accSupport);
185
186 bool hasIllegalSymbols = !invalidSyms.empty();
187
188 if (hasIllegalSymbols) {
189 auto getSymName = [&](SymbolTable::SymbolUse symUse) -> std::string {
190 return symUse.getSymbolRef().getLeafReference().str();
191 };
192 std::string invalidString =
193 llvm::join(llvm::map_range(invalidSyms, getSymName), ", ");
194
195 // Emit only warnings when softCheck is enabled.
196 if (softCheck)
197 regionOp->emitWarning("offload target verifier: illegal symbol(s): ")
198 << invalidString;
199 else
200 accSupport.emitNYI(regionOp->getLoc(),
201 "offload target verifier failed due to illegal "
202 "symbol(s): " +
203 invalidString);
204 }
205
206 return hasIllegalSymbols;
207 }
208
209 void runOnOperation() override {
210 LLVM_DEBUG(llvm::dbgs() << "Enter OffloadTargetVerifier()\n");
211 func::FuncOp func = getOperation();
212
213 // Try to get cached parent analysis first, fall back to local analysis.
214 auto cachedAnalysis =
215 getCachedParentAnalysis<acc::OpenACCSupport>(func->getParentOp());
216 acc::OpenACCSupport &accSupport = cachedAnalysis
217 ? cachedAnalysis->get()
218 : getAnalysis<acc::OpenACCSupport>();
219
220 bool hasErrors = false;
221
222 func.walk([&](Operation *op) {
223 // Only process offload region operations.
224 if (!isa<acc::OffloadRegionOpInterface>(op))
225 return WalkResult::advance();
226
227 // TODO: Host/multicore verification for ACC compute constructs is not yet
228 // implemented.
229 if (isHostTarget() && isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) {
230 accSupport.emitNYI(op->getLoc(),
231 "host/multicore verification for ACC compute "
232 "constructs");
233 return WalkResult::advance();
234 }
235
236 // Check for illegal live-in values.
237 bool hasIllegalValues = hasIllegalLiveInValues(op, accSupport);
238 if (hasIllegalValues)
239 hasErrors = true;
240
241 // Check for illegal symbol uses.
242 bool hasIllegalSyms = hasIllegalSymbolUses(op, accSupport);
243 if (hasIllegalSyms)
244 hasErrors = true;
245
246 if (!hasIllegalValues && !hasIllegalSyms && softCheck)
247 op->emitRemark("offload target verifier: passed validity check");
248
249 return WalkResult::advance();
250 });
251
252 if (hasErrors && !softCheck)
253 signalPassFailure();
254
255 LLVM_DEBUG(llvm::dbgs() << "Exit OffloadTargetVerifier()\n");
256 }
257};
258
259} // namespace
static std::string diag(const llvm::Value &value)
Represents an analysis for computing liveness information from a given top-level operation.
Definition Liveness.h:47
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
Definition Liveness.cpp:231
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitRemark(const Twine &message={})
Emit a remark about this operation, reporting up to any diagnostic handlers that may be listening.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
This class represents a specific symbol use.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
bool isValidValueUse(Value v, Region &region)
Check if a value use is legal in an OpenACC region.
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
std::string getVariableName(Value v)
Get the variable name for a given value.
Include the generated interface declarations.