MLIR 22.0.0git
OffloadLiveInValueCanonicalization.cpp
Go to the documentation of this file.
1//===- OffloadLiveInValueCanonicalization.cpp -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass canonicalizes live-in values for regions destined for offloading.
10//
11// Overview:
12// ---------
13// When a region is outlined (extracted into a separate function for device
14// execution), values defined outside the region but used inside become
15// arguments to the outlined function. However, some values cannot be passed
16// as arguments because they represent synthetic types (e.g., shape metadata,
17// field indices) or are better handled by recreating them inside the region.
18//
19// This pass identifies such values and either:
20// 1. Sinks the defining operation into the region (if all uses are inside)
21// 2. Rematerializes (clones) the operation inside the region (if there are
22// uses both inside and outside)
23//
24// Transforms:
25// -----------
26// The pass performs two main transformations on live-in values:
27//
28// 1. Sinking: If a candidate operation's result is only used inside the
29// offload region, the operation is moved into the region.
30//
31// 2. Rematerialization: If a candidate operation's result is used both
32// inside and outside the region, the operation is cloned inside the
33// region and uses within the region are updated to use the clone.
34//
35// Candidate operations are:
36// - Constants (matching arith.constant, etc.)
37// - Operations implementing `acc::OutlineRematerializationOpInterface`
38// - Address-of operations (`acc::AddressOfGlobalOpInterface`) referencing
39// symbols that are valid in GPU regions or constant globals
40//
41// The pass traces through view-like operations (`ViewLikeOpInterface`) and
42// partial entity access operations (`acc::PartialEntityAccessOpInterface`)
43// to find the original defining operation before making candidate decisions.
44//
45// Requirements:
46// -------------
47// To use this pass in a pipeline, the following requirements must be met:
48//
49// 1. Target Region Identification: Operations representing offload regions
50// must implement `acc::OffloadRegionOpInterface`. This interface marks
51// regions that will be outlined for device execution.
52//
53// 2. Rematerialization Candidates: Operations producing values that should
54// be rematerialized (rather than passed as arguments) should implement
55// `acc::OutlineRematerializationOpInterface`. Examples include operations
56// producing shape metadata, field indices, or other synthetic types.
57//
58// 3. Analysis Registration (Optional): If custom behavior is needed for
59// symbol validation (e.g., determining if a global is valid on device),
60// pre-register `acc::OpenACCSupport` analysis on the parent module.
61// If not registered, default behavior will be used.
62//
63// 4. View-Like Operations: Operations that create views or casts should
64// implement `ViewLikeOpInterface` or `acc::PartialEntityAccessOpInterface`
65// to allow the pass to trace through to the original defining operation.
66//
67//===----------------------------------------------------------------------===//
68
70
75#include "mlir/IR/Builders.h"
76#include "mlir/IR/Matchers.h"
77#include "mlir/IR/Operation.h"
78#include "mlir/IR/Region.h"
79#include "mlir/IR/SymbolTable.h"
80#include "mlir/IR/Value.h"
83#include "mlir/Pass/Pass.h"
84#include "mlir/Support/LLVM.h"
86
87namespace mlir {
88namespace acc {
89#define GEN_PASS_DEF_OFFLOADLIVEINVALUECANONICALIZATION
90#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
91} // namespace acc
92} // namespace mlir
93
94#define DEBUG_TYPE "offload-livein-value-canonicalization"
95
96using namespace mlir;
97
98namespace {
99
100/// Returns true if all users of the given value are inside the region.
101static bool allUsersAreInsideRegion(Value val, Region &region) {
102 for (Operation *user : val.getUsers())
103 if (!region.isAncestor(user->getParentRegion()))
104 return false;
105 return true;
106}
107
108/// Traces through view-like and partial entity access operations to find the
109/// original defining value.
110static Value getOriginalValue(Value val) {
111 Value prev;
112 while (val && val != prev) {
113 prev = val;
114 if (auto viewLikeOp = val.getDefiningOp<ViewLikeOpInterface>())
115 val = viewLikeOp.getViewSource();
116 if (auto partialAccess =
117 val.getDefiningOp<acc::PartialEntityAccessOpInterface>()) {
118 Value base = partialAccess.getBaseEntity();
119 if (base)
120 val = base;
121 }
122 }
123 return val;
124}
125
126/// Returns true if the operation is a candidate for rematerialization.
127/// Candidates are operations that:
128/// 1. Match the constant pattern (arith.constant, etc.)
129/// 2. Implement OutlineRematerializationOpInterface
130/// 3. Are address-of operations referencing valid symbols or constant globals
131/// The function traces through view-like operations (casts, reinterpret_cast)
132/// to find the original defining operation before making the determination.
133static bool isRematerializationCandidate(Value val,
134 acc::OpenACCSupport &accSupport) {
135 // Trace through view-like operations to find the original value.
136 Value origVal = getOriginalValue(val);
137 Operation *definingOp = origVal.getDefiningOp();
138 if (!definingOp)
139 return false;
140
141 LLVM_DEBUG(llvm::dbgs() << "\tChecking candidate: " << *definingOp << "\n");
142
143 // Constants are trivial and useful to rematerialize.
144 if (matchPattern(definingOp, m_Constant())) {
145 LLVM_DEBUG(llvm::dbgs() << "\t\t-> constant pattern matched\n");
146 return true;
147 }
148
149 // Operations implementing OutlineRematerializationOpInterface are candidates.
150 if (isa<acc::OutlineRematerializationOpInterface>(definingOp)) {
151 LLVM_DEBUG(llvm::dbgs() << "\t\t-> OutlineRematerializationOpInterface\n");
152 return true;
153 }
154
155 // Address-of operations referencing globals that are valid in GPU regions
156 // or referencing constant globals should be rematerialized.
157 if (auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(definingOp)) {
158 SymbolRefAttr symbol = addrOfOp.getSymbol();
159 LLVM_DEBUG(llvm::dbgs()
160 << "\t\tAddressOfGlobalOpInterface, symbol: " << symbol << "\n");
161
162 // If the symbol is already valid in GPU regions (e.g., has acc.declare),
163 // rematerializing ensures the address refers to the device copy.
164 Operation *globalOp = nullptr;
165 if (accSupport.isValidSymbolUse(definingOp, symbol, &globalOp)) {
166 LLVM_DEBUG(llvm::dbgs() << "\t\t-> isValidSymbolUse: true\n");
167 return true;
168 }
169 LLVM_DEBUG(llvm::dbgs() << "\t\t-> isValidSymbolUse: false\n");
170
171 // If the referenced global is constant, prefer rematerialization so the
172 // constant can be placed in GPU memory.
173 if (globalOp) {
174 if (auto globalVarOp =
175 dyn_cast<acc::GlobalVariableOpInterface>(globalOp)) {
176 if (globalVarOp.isConstant()) {
177 LLVM_DEBUG(llvm::dbgs() << "\t\t-> constant global\n");
178 return true;
179 }
180 }
181 }
182 }
183
184 LLVM_DEBUG(llvm::dbgs() << "\t\t-> not a candidate\n");
185 return false;
186}
187
188class OffloadLiveInValueCanonicalization
189 : public acc::impl::OffloadLiveInValueCanonicalizationBase<
190 OffloadLiveInValueCanonicalization> {
191public:
192 using acc::impl::OffloadLiveInValueCanonicalizationBase<
193 OffloadLiveInValueCanonicalization>::
194 OffloadLiveInValueCanonicalizationBase;
195
196 /// Canonicalizes live-in values for a region by sinking or rematerializing
197 /// operations. Returns true if any changes were made.
198 bool canonicalizeLiveInValues(Region &region,
199 acc::OpenACCSupport &accSupport) {
200 // 1) Collect live-in values.
201 SetVector<Value> liveInValues;
202 getUsedValuesDefinedAbove(region, liveInValues);
203 LLVM_DEBUG(llvm::dbgs()
204 << "\tFound " << liveInValues.size() << " live-in value(s)\n");
205
206 auto isSinkCandidate = [&region, &accSupport](Value val) -> bool {
207 return isRematerializationCandidate(val, accSupport) &&
208 allUsersAreInsideRegion(val, region);
209 };
210 auto isCloneCandidate = [&region, &accSupport](Value val) -> bool {
211 return isRematerializationCandidate(val, accSupport) &&
212 !allUsersAreInsideRegion(val, region);
213 };
214
215 // 2) Filter values into two sets - sink and rematerialization candidates.
216 SmallVector<Value> sinkCandidates(
217 llvm::make_filter_range(liveInValues, isSinkCandidate));
218 SmallVector<Value> rematerializationCandidates(
219 llvm::make_filter_range(liveInValues, isCloneCandidate));
220
221 LLVM_DEBUG(llvm::dbgs() << "\tSink candidates: " << sinkCandidates.size()
222 << ", clone candidates: "
223 << rematerializationCandidates.size() << "\n");
224
225 if (rematerializationCandidates.empty() && sinkCandidates.empty())
226 return false;
227
228 LLVM_DEBUG(llvm::dbgs() << "\tCanonicalizing values into "
229 << *region.getParentOp() << "\n");
230
231 // 3) Handle the sink set by moving the operations into the region.
232 for (Value sinkCandidate : sinkCandidates) {
233 Operation *sinkOp = sinkCandidate.getDefiningOp();
234 assert(sinkOp && "must have op to be considered");
235 sinkOp->moveBefore(&region.front().front());
236 LLVM_DEBUG(llvm::dbgs() << "\t\tSunk: " << *sinkOp << "\n");
237 }
238
239 // 4) Handle the rematerialization set by copying the operations into
240 // the region.
241 OpBuilder builder(region);
242 SmallVector<Operation *> opsToRematerialize;
243 for (Value rematerializationCandidate : rematerializationCandidates) {
244 Operation *rematerializationOp =
245 rematerializationCandidate.getDefiningOp();
246 assert(rematerializationOp && "must have op to be considered");
247 opsToRematerialize.push_back(rematerializationOp);
248 }
249 computeTopologicalSorting(opsToRematerialize);
250 for (Operation *rematerializationOp : opsToRematerialize) {
251 Operation *clonedOp = builder.clone(*rematerializationOp);
252 for (auto [oldResult, newResult] : llvm::zip(
253 rematerializationOp->getResults(), clonedOp->getResults())) {
254 replaceAllUsesInRegionWith(oldResult, newResult, region);
255 }
256 LLVM_DEBUG(llvm::dbgs() << "\t\tCloned: " << *clonedOp << "\n");
257 }
258
259 return true;
260 }
261
262 void runOnOperation() override {
263 LLVM_DEBUG(llvm::dbgs() << "Enter OffloadLiveInValueCanonicalization\n");
264
265 // Since OpenACCSupport is normally registered on modules, attempt to
266 // get it from the parent module first (if available), then fallback
267 // to the per-function analysis.
268 acc::OpenACCSupport *accSupportPtr = nullptr;
269 if (auto parentAnalysis = getCachedParentAnalysis<acc::OpenACCSupport>())
270 accSupportPtr = &parentAnalysis->get();
271 else
272 accSupportPtr = &getAnalysis<acc::OpenACCSupport>();
273 acc::OpenACCSupport &accSupport = *accSupportPtr;
274
275 func::FuncOp func = getOperation();
276 LLVM_DEBUG(llvm::dbgs()
277 << "Processing function: " << func.getName() << "\n");
278
279 func.walk([&](Operation *op) {
280 if (isa<acc::OffloadRegionOpInterface>(op)) {
281 LLVM_DEBUG(llvm::dbgs()
282 << "Found offload region: " << op->getName() << "\n");
283 assert(op->getNumRegions() == 1 && "must have 1 region");
284
285 // Canonicalization of values changes live-in set.
286 // Rerun the algorithm until convergence.
287 bool changes = false;
288 [[maybe_unused]] int iteration = 0;
289 do {
290 LLVM_DEBUG(llvm::dbgs() << "\tIteration " << iteration++ << "\n");
291 changes = canonicalizeLiveInValues(op->getRegion(0), accSupport);
292 } while (changes);
293 LLVM_DEBUG(llvm::dbgs()
294 << "\tConverged after " << iteration << " iteration(s)\n");
295 }
296 });
297
298 LLVM_DEBUG(llvm::dbgs() << "Exit OffloadLiveInValueCanonicalization\n");
299 }
300};
301
302} // namespace
Operation & front()
Definition Block.h:163
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, Operation **definingOpPtr=nullptr)
Check if a symbol use is valid for use in an OpenACC region.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369