MLIR 22.0.0git
OpenMPOffloadPrivatizationPrepare.cpp
Go to the documentation of this file.
1//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/Support/DebugLog.h"
19#include "llvm/Support/FormatVariadic.h"
20#include <cstdint>
21#include <iterator>
22#include <utility>
23
24//===----------------------------------------------------------------------===//
25// A pass that prepares OpenMP code for translation of delayed privatization
26// in the context of deferred target tasks. Deferred target tasks are created
27// when the nowait clause is used on the target directive.
28//===----------------------------------------------------------------------===//
29
30#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
31
32namespace mlir {
33namespace omp {
34
35#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
36#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
37
38} // namespace omp
39} // namespace mlir
40
41using namespace mlir;
42namespace {
43
44//===----------------------------------------------------------------------===//
45// PrepareForOMPOffloadPrivatizationPass
46//===----------------------------------------------------------------------===//
47
48class PrepareForOMPOffloadPrivatizationPass
49 : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
50 PrepareForOMPOffloadPrivatizationPass> {
51
52 void runOnOperation() override {
53 ModuleOp mod = getOperation();
54
55 // In this pass, we make host-allocated privatized variables persist for
56 // deferred target tasks by copying them to the heap. Once the target task
57 // is done, this heap memory is freed. Since all of this happens on the host
58 // we can skip device modules.
59 auto offloadModuleInterface =
60 dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
61 if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
62 return;
63
64 getOperation()->walk([&](omp::TargetOp targetOp) {
65 if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
66 return;
67 IRRewriter rewriter(&getContext());
68 OperandRange privateVars = targetOp.getPrivateVars();
69 SmallVector<mlir::Value> newPrivVars;
70 Value fakeDependVar;
71 omp::TaskOp cleanupTaskOp;
72
73 newPrivVars.reserve(privateVars.size());
74 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
75 for (auto [privVarIdx, privVarSymPair] :
76 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
77 Value privVar = std::get<0>(privVarSymPair);
78 Attribute privSym = std::get<1>(privVarSymPair);
79
80 omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
81 if (!privatizer.needsMap()) {
82 newPrivVars.push_back(privVar);
83 continue;
84 }
85 bool isFirstPrivate = privatizer.getDataSharingType() ==
86 omp::DataSharingClauseType::FirstPrivate;
87
88 Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
89 auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
90
91 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
92 newPrivVars.push_back(privVar);
93 continue;
94 }
95
96 // For deferred target tasks (!$omp target nowait), we need to keep
97 // a copy of the original, i.e. host variable being privatized so
98 // that it is available when the target task is eventually executed.
99 // We do this by first allocating as much heap memory as is needed by
100 // the original variable. Then, we use the init and copy regions of the
101 // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
102 // allocated copy.
103 // After the target task is done, we need to use the dealloc region
104 // of the privatizer to clean up everything. We also need to free
105 // the heap memory we allocated. But due to the deferred nature
106 // of the target task, we cannot simply deallocate right after the
107 // omp.target operation else we may end up freeing memory before
108 // its eventual use by the target task. So, we create a dummy
109 // dependence between the target task and new omp.task. In the omp.task,
110 // we do all the cleanup. So, we end up with the following structure
111 //
112 // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
113 // ...
114 // omp.terminator
115 // }
116 // omp.task depend(in: fakeDependVar) {
117 // /*cleanup_code*/
118 // omp.terminator
119 // }
120 // fakeDependVar is the address of the first heap-allocated copy of the
121 // host variable being privatized.
122
123 bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
124
125 // Allocate heap memory that corresponds to the type of memory
126 // pointed to by varPtr
127 // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
128 // should have mapped the pointer to the boxchar so use that as varPtr.
129 Value varPtr = mapInfoOp.getVarPtr();
130 Type varType = mapInfoOp.getVarType();
131 bool isPrivatizedByValue =
132 !isa<LLVM::LLVMPointerType>(privVar.getType());
133
134 assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
135 Value heapMem =
136 allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
137 if (!heapMem)
138 targetOp.emitError(
139 "Unable to allocate heap memory when trying to move "
140 "a private variable out of the stack and into the "
141 "heap for use by a deferred target task");
142
143 if (needsCleanupTask && !fakeDependVar)
144 fakeDependVar = heapMem;
145
146 // The types of private vars should match before and after the
147 // transformation. In particular, if the type is a pointer,
148 // simply record the newly allocated malloc location as the
149 // new private variable. If, however, the type is not a pointer
150 // then, we need to load the value from the newly allocated
151 // location. We'll insert that load later after we have updated
152 // the malloc'd location with the contents of the original
153 // variable.
154 if (!isPrivatizedByValue)
155 newPrivVars.push_back(heapMem);
156
157 // We now need to copy the original private variable into the newly
158 // allocated location in the heap.
159 // Find the earliest insertion point for the copy. This will be before
160 // the first in the list of omp::MapInfoOp instances that use varPtr.
161 // After the copy these omp::MapInfoOp instances will refer to heapMem
162 // instead.
163 Operation *varPtrDefiningOp = varPtr.getDefiningOp();
165 if (varPtrDefiningOp) {
166 users.insert(varPtrDefiningOp->user_begin(),
167 varPtrDefiningOp->user_end());
168 } else {
169 auto blockArg = cast<BlockArgument>(varPtr);
170 users.insert(blockArg.user_begin(), blockArg.user_end());
171 }
172 auto usesVarPtr = [&users](Operation *op) -> bool {
173 return users.count(op);
174 };
175
176 SmallVector<Operation *> chainOfOps;
177 chainOfOps.push_back(mapInfoOp);
178 for (auto member : mapInfoOp.getMembers()) {
179 omp::MapInfoOp memberMap =
180 cast<omp::MapInfoOp>(member.getDefiningOp());
181 if (usesVarPtr(memberMap))
182 chainOfOps.push_back(memberMap);
183 if (memberMap.getVarPtrPtr()) {
184 Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
185 if (defOp && usesVarPtr(defOp))
186 chainOfOps.push_back(defOp);
187 }
188 }
189
190 DominanceInfo dom;
191 llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
192 if (l == r)
193 return false;
194 return dom.properlyDominates(l, r);
195 });
196
197 rewriter.setInsertionPoint(chainOfOps.front());
198
199 Operation *firstOp = chainOfOps.front();
200 Location loc = firstOp->getLoc();
201
202 // Create a llvm.func for 'region' that is marked always_inline and call
203 // it.
204 auto createAlwaysInlineFuncAndCallIt =
205 [&](Region &region, llvm::StringRef funcName,
206 llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
207 assert(!region.empty() && "region cannot be empty");
208 LLVM::LLVMFuncOp func = createFuncOpForRegion(
209 loc, mod, region, funcName, rewriter, returnsValue);
210 auto call = LLVM::CallOp::create(rewriter, loc, func, args);
211 return call.getResult();
212 };
213
214 Value moldArg, newArg;
215 if (isPrivatizedByValue) {
216 moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
217 newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
218 } else {
219 moldArg = varPtr;
220 newArg = heapMem;
221 }
222
223 Value initializedVal;
224 if (!privatizer.getInitRegion().empty())
225 initializedVal = createAlwaysInlineFuncAndCallIt(
226 privatizer.getInitRegion(),
227 llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
228 {moldArg, newArg}, /*returnsValue=*/true);
229 else
230 initializedVal = newArg;
231
232 if (isFirstPrivate && !privatizer.getCopyRegion().empty())
233 initializedVal = createAlwaysInlineFuncAndCallIt(
234 privatizer.getCopyRegion(),
235 llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
236 {moldArg, initializedVal}, /*returnsValue=*/true);
237
238 if (isPrivatizedByValue)
239 (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
240
241 // clone origOp, replace all uses of varPtr with heapMem and
242 // erase origOp.
243 auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
244 Operation *clonedOp = rewriter.clone(*origOp);
245 rewriter.replaceAllOpUsesWith(origOp, clonedOp);
246 rewriter.modifyOpInPlace(clonedOp, [&]() {
247 clonedOp->replaceUsesOfWith(varPtr, heapMem);
248 });
249 rewriter.eraseOp(origOp);
250 return clonedOp;
251 };
252
253 // Now that we have set up the heap-allocated copy of the private
254 // variable, rewrite all the uses of the original variable with
255 // the heap-allocated variable.
256 rewriter.setInsertionPoint(targetOp);
257 mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
258 rewriter.setInsertionPoint(mapInfoOp);
259
260 // Fix any members that may use varPtr to now use heapMem
261 for (auto member : mapInfoOp.getMembers()) {
262 auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
263 if (!usesVarPtr(memberMapInfoOp))
264 continue;
265 memberMapInfoOp =
266 cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
267 rewriter.setInsertionPoint(memberMapInfoOp);
268
269 if (memberMapInfoOp.getVarPtrPtr()) {
270 Operation *varPtrPtrdefOp =
271 memberMapInfoOp.getVarPtrPtr().getDefiningOp();
272 rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
273 }
274 }
275
276 // If the type of the private variable is not a pointer,
277 // which is typically the case with !fir.boxchar types, then
278 // we need to ensure that the new private variable is also
279 // not a pointer. Insert a load from heapMem right before
280 // targetOp.
281 if (isPrivatizedByValue) {
282 rewriter.setInsertionPoint(targetOp);
283 auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
284 varType, heapMem);
285 newPrivVars.push_back(newPrivVar);
286 }
287
288 // Deallocate
289 if (needsCleanupTask) {
290 if (!cleanupTaskOp) {
291 assert(fakeDependVar &&
292 "Need a valid value to set up a dependency");
293 rewriter.setInsertionPointAfter(targetOp);
294 omp::TaskOperands taskOperands;
295 auto inDepend = omp::ClauseTaskDependAttr::get(
296 rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
297 taskOperands.dependKinds.push_back(inDepend);
298 taskOperands.dependVars.push_back(fakeDependVar);
299 cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
300 Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
301 rewriter.setInsertionPointToEnd(taskBlock);
302 omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
303 }
305 &*cleanupTaskOp.getRegion().getBlocks().begin());
306 (void)createAlwaysInlineFuncAndCallIt(
307 privatizer.getDeallocRegion(),
308 llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
309 .str(),
310 {initializedVal}, /*returnsValue=*/false);
311 llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
312 LLVM::lookupOrCreateFreeFn(rewriter, mod);
313 assert(llvm::succeeded(freeFunc) &&
314 "Could not find free in the module");
315 (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
316 ValueRange{heapMem});
317 }
318 }
319 assert(newPrivVars.size() == privateVars.size() &&
320 "The number of private variables must match before and after "
321 "transformation");
322 if (fakeDependVar) {
323 omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
324 rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
325 SmallVector<Attribute> newDependKinds;
326 if (!targetOp.getDependVars().empty()) {
327 std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
328 assert(dependKinds && "bad depend clause in omp::TargetOp");
329 llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
330 }
331 newDependKinds.push_back(outDepend);
332 ArrayAttr newDependKindsAttr =
333 ArrayAttr::get(rewriter.getContext(), newDependKinds);
334 targetOp.getDependVarsMutable().append(fakeDependVar);
335 targetOp.setDependKindsAttr(newDependKindsAttr);
336 }
337 rewriter.setInsertionPoint(targetOp);
338 targetOp.getPrivateVarsMutable().clear();
339 targetOp.getPrivateVarsMutable().assign(newPrivVars);
340 });
341 }
342
343private:
344 bool hasPrivateVars(omp::TargetOp targetOp) const {
345 return !targetOp.getPrivateVars().empty();
346 }
347
348 bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
349 return targetOp.getNowait();
350 }
351
352 template <typename OpTy>
353 omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
354 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
355 omp::PrivateClauseOp privatizer =
357 op, privatizerName);
358 return privatizer;
359 }
360
361 // Get the (compile-time constant) size of varType as per the
362 // given DataLayout dl.
363 std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
364 llvm::TypeSize size = dl.getTypeSize(varType);
365 unsigned short alignment = dl.getTypeABIAlignment(varType);
366 return llvm::alignTo(size, alignment);
367 }
368
369 LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
370 llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
371 LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
372 assert(llvm::succeeded(mallocCall) &&
373 "Could not find malloc in the module");
374 return mallocCall.value();
375 }
376
377 Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
378 ModuleOp mod, IRRewriter &rewriter) const {
379 OpBuilder::InsertionGuard guard(rewriter);
380 Value varPtr = privVar;
381 Operation *definingOp = varPtr.getDefiningOp();
382 BlockArgument blockArg;
383 if (!definingOp) {
384 blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
385 rewriter.setInsertionPointToStart(blockArg.getParentBlock());
386 } else {
387 rewriter.setInsertionPoint(definingOp);
388 }
389 Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
390 LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
391
392 assert(mod.getDataLayoutSpec() &&
393 "MLIR module with no datalayout spec not handled yet");
394
395 const DataLayout &dl = DataLayout(mod);
396 std::int64_t distance = getSizeInBytes(dl, varType);
397
398 Value sizeBytes = LLVM::ConstantOp::create(
399 rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
400
401 auto mallocCallOp =
402 LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
403 return mallocCallOp.getResult();
404 }
405
406 // Create a function for srcRegion and attribute it to be always_inline.
407 // The big assumption here is that srcRegion is one of init, copy or dealloc
408 // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
409 // to either be the same as the types of the two arguments of the region (for
410 // init and copy regions) or void as would be the case for dealloc regions.
411 LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
412 Region &srcRegion,
413 llvm::StringRef funcName,
414 IRRewriter &rewriter,
415 bool returnsValue = false) {
416
417 OpBuilder::InsertionGuard guard(rewriter);
418 rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
419 Region clonedRegion;
420 IRMapping mapper;
421 srcRegion.cloneInto(&clonedRegion, mapper);
422
423 SmallVector<Type> paramTypes;
424 llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
425 Type resultType = returnsValue
426 ? srcRegion.getArgument(0).getType()
427 : LLVM::LLVMVoidType::get(rewriter.getContext());
428 LLVM::LLVMFunctionType funcType =
429 LLVM::LLVMFunctionType::get(resultType, paramTypes);
430
431 LLVM::LLVMFuncOp func =
432 LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
433 func.setAlwaysInline(true);
434 rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
435 func.getRegion().end());
436 for (auto &block : func.getRegion().getBlocks()) {
437 if (isa<omp::YieldOp>(block.getTerminator())) {
438 omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
439 rewriter.setInsertionPoint(yieldOp);
440 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
441 yieldOp.getOperands());
442 }
443 }
444 return func;
445 }
446};
447} // namespace
ArrayAttr()
b getContext())
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
IntegerType getI64Type()
Definition Builders.cpp:65
MLIRContext * getContext() const
Definition Builders.h:56
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
user_iterator user_end()
Definition Operation.h:870
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
user_iterator user_begin()
Definition Operation.h:869
bool empty()
Definition Region.h:60
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition Region.cpp:70
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
Definition Region.cpp:36
BlockArgument getArgument(unsigned i)
Definition Region.h:124
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128