MLIR 22.0.0git
OpenMPOffloadPrivatizationPrepare.cpp
Go to the documentation of this file.
1//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/Support/DebugLog.h"
19#include "llvm/Support/FormatVariadic.h"
20#include <cstdint>
21#include <iterator>
22#include <utility>
23
24//===----------------------------------------------------------------------===//
25// A pass that prepares OpenMP code for translation of delayed privatization
26// in the context of deferred target tasks. Deferred target tasks are created
27// when the nowait clause is used on the target directive.
28//===----------------------------------------------------------------------===//
30#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
// The generated pass declarations are pulled in inside mlir::omp. Defining
// GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS before the include
// selects this pass's section of Passes.h.inc.
// NOTE(review): presumably this is what provides the generated base class the
// pass below derives from; the base-class line is not visible in this
// listing, so confirm against the upstream file.
32namespace mlir {
33namespace omp {
35#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
36#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
38} // namespace omp
39} // namespace mlir
41using namespace mlir;
42namespace {
43
44//===----------------------------------------------------------------------===//
45// PrepareForOMPOffloadPrivatizationPass
46//===----------------------------------------------------------------------===//
47
48class PrepareForOMPOffloadPrivatizationPass
50 PrepareForOMPOffloadPrivatizationPass> {
// Entry point of the pass.
//
// Returns immediately for modules whose omp::OffloadModuleInterface reports
// getIsTargetDevice(). Otherwise walks every omp::TargetOp that has private
// variables and the nowait clause. For each private variable whose privatizer
// needsMap() and whose map info is not captured ByCopy, it:
//   1. emits a malloc call (allocateHeapMem) sized for the mapped var type,
//   2. calls the privatizer's init region and, for firstprivate, its copy
//      region on that memory; each region is outlined into an always_inline
//      llvm.func by createFuncOpForRegion,
//   3. re-creates the omp::MapInfoOp (plus member maps / varPtrPtr producers
//      that use varPtr) so that they refer to the heap memory instead,
//   4. when the privatizer has a dealloc region, adds calls to that region
//      and to free inside a single omp.task created after the target, tied to
//      the target through a depend(out)/depend(in) pair on fakeDependVar.
// All other private variables are forwarded unchanged. At the end the
// target's private-var operands are replaced with newPrivVars.
52 void runOnOperation() override {
53 ModuleOp mod = getOperation();
54
55 // In this pass, we make host-allocated privatized variables persist for
56 // deferred target tasks by copying them to the heap. Once the target task
57 // is done, this heap memory is freed. Since all of this happens on the host
58 // we can skip device modules.
59 auto offloadModuleInterface =
60 dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
61 if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
62 return;
63
64 getOperation()->walk([&](omp::TargetOp targetOp) {
65 if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
66 return;
67 IRRewriter rewriter(&getContext());
68 OperandRange privateVars = targetOp.getPrivateVars();
// One entry per original private var, in the same order (asserted below).
69 SmallVector<mlir::Value> newPrivVars;
// Set to the first heap allocation whose privatizer needs cleanup; used as
// the dependence variable between the target and the cleanup task.
70 Value fakeDependVar;
// Created lazily; shared by every private var of this target that needs
// cleanup.
71 omp::TaskOp cleanupTaskOp;
72
73 newPrivVars.reserve(privateVars.size());
// NOTE(review): privateSyms is dereferenced below without checking that the
// optional holds a value; this relies on private vars always coming with
// private syms - confirm.
74 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
75 for (auto [privVarIdx, privVarSymPair] :
76 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
77 Value privVar = std::get<0>(privVarSymPair);
78 Attribute privSym = std::get<1>(privVarSymPair);
// NOTE(review): the privatizer returned by the lookup is used without a null
// check.
80 omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
// Privatizers that do not need a map are left untouched.
81 if (!privatizer.needsMap()) {
82 newPrivVars.push_back(privVar);
83 continue;
84 }
85 bool isFirstPrivate = privatizer.getDataSharingType() ==
86 omp::DataSharingClauseType::FirstPrivate;
87
88 Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
89 auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
90
// Variables captured by copy are left untouched as well.
91 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
92 newPrivVars.push_back(privVar);
93 continue;
94 }
95
96 // For deferred target tasks (!$omp target nowait), we need to keep
97 // a copy of the original, i.e. host variable being privatized so
98 // that it is available when the target task is eventually executed.
99 // We do this by first allocating as much heap memory as is needed by
100 // the original variable. Then, we use the init and copy regions of the
101 // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
102 // allocated copy.
103 // After the target task is done, we need to use the dealloc region
104 // of the privatizer to clean up everything. We also need to free
105 // the heap memory we allocated. But due to the deferred nature
106 // of the target task, we cannot simply deallocate right after the
107 // omp.target operation else we may end up freeing memory before
108 // its eventual use by the target task. So, we create a dummy
109 // dependence between the target task and new omp.task. In the omp.task,
110 // we do all the cleanup. So, we end up with the following structure
111 //
112 // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
113 // ...
114 // omp.terminator
115 // }
116 // omp.task depend(in: fakeDependVar) {
117 // /*cleanup_code*/
118 // omp.terminator
119 // }
120 // fakeDependVar is the address of the first heap-allocated copy of the
121 // host variable being privatized.
122
123 bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
124
125 // Allocate heap memory that corresponds to the type of memory
126 // pointed to by varPtr
127 // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
128 // should have mapped the pointer to the boxchar so use that as varPtr.
129 Value varPtr = mapInfoOp.getVarPtr();
130 Type varType = mapInfoOp.getVarType();
131 bool isPrivatizedByValue =
132 !isa<LLVM::LLVMPointerType>(privVar.getType());
133
134 assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
135 Value heapMem =
136 allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
// NOTE(review): the error is emitted but the pass is not marked as failed and
// this iteration is not abandoned, so a null heapMem would still be used
// below.
137 if (!heapMem)
138 targetOp.emitError(
139 "Unable to allocate heap memory when trying to move "
140 "a private variable out of the stack and into the "
141 "heap for use by a deferred target task");
142
143 if (needsCleanupTask && !fakeDependVar)
144 fakeDependVar = heapMem;
145
146 // The types of private vars should match before and after the
147 // transformation. In particular, if the type is a pointer,
148 // simply record the newly allocated malloc location as the
149 // new private variable. If, however, the type is not a pointer
150 // then, we need to load the value from the newly allocated
151 // location. We'll insert that load later after we have updated
152 // the malloc'd location with the contents of the original
153 // variable.
154 if (!isPrivatizedByValue)
155 newPrivVars.push_back(heapMem);
156
157 // We now need to copy the original private variable into the newly
158 // allocated location in the heap.
159 // Find the earliest insertion point for the copy. This will be before
160 // the first in the list of omp::MapInfoOp instances that use varPtr.
161 // After the copy these omp::MapInfoOp instances will refer to heapMem
162 // instead.
163 Operation *varPtrDefiningOp = varPtr.getDefiningOp();
// NOTE(review): `users` is filled and queried below, but its declaration is
// not visible in this listing (original line 164 is missing). From its use it
// is a set of Operation*; confirm the exact type against the upstream file.
165 if (varPtrDefiningOp) {
166 users.insert(varPtrDefiningOp->user_begin(),
167 varPtrDefiningOp->user_end());
168 } else {
169 auto blockArg = cast<BlockArgument>(varPtr);
170 users.insert(blockArg.user_begin(), blockArg.user_end());
171 }
172 auto usesVarPtr = [&users](Operation *op) -> bool {
173 return users.count(op);
174 };
175
// Collect the map info op, the member maps that use varPtr, and the
// varPtrPtr-defining ops that use varPtr.
176 SmallVector<Operation *> chainOfOps;
177 chainOfOps.push_back(mapInfoOp);
178 for (auto member : mapInfoOp.getMembers()) {
179 omp::MapInfoOp memberMap =
180 cast<omp::MapInfoOp>(member.getDefiningOp());
181 if (usesVarPtr(memberMap))
182 chainOfOps.push_back(memberMap);
183 if (memberMap.getVarPtrPtr()) {
184 Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
185 if (defOp && usesVarPtr(defOp))
186 chainOfOps.push_back(defOp);
187 }
188 }
189
// Order the collected ops by dominance so that front() is the earliest one.
190 DominanceInfo dom;
191 llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
192 if (l == r)
193 return false;
194 return dom.properlyDominates(l, r);
195 });
196
197 rewriter.setInsertionPoint(chainOfOps.front());
198
199 Operation *firstOp = chainOfOps.front();
// `loc` (the location of the earliest op in the chain) is reused for every op
// created for this private variable, including the cleanup task further down.
200 Location loc = firstOp->getLoc();
201
202 // Create a llvm.func for 'region' that is marked always_inline and call
203 // it.
204 auto createAlwaysInlineFuncAndCallIt =
205 [&](Region &region, llvm::StringRef funcName,
206 llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
207 assert(!region.empty() && "region cannot be empty");
208 LLVM::LLVMFuncOp func = createFuncOpForRegion(
209 loc, mod, region, funcName, rewriter, returnsValue);
210 auto call = LLVM::CallOp::create(rewriter, loc, func, args);
211 return call.getResult();
212 };
213
// Arguments for the init/copy regions: loaded values when the variable is
// privatized by value, the pointers themselves otherwise.
214 Value moldArg, newArg;
215 if (isPrivatizedByValue) {
216 moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
217 newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
218 } else {
219 moldArg = varPtr;
220 newArg = heapMem;
221 }
222
223 Value initializedVal;
224 if (!privatizer.getInitRegion().empty())
225 initializedVal = createAlwaysInlineFuncAndCallIt(
226 privatizer.getInitRegion(),
227 llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
228 {moldArg, newArg}, /*returnsValue=*/true);
229 else
230 initializedVal = newArg;
231
232 if (isFirstPrivate && !privatizer.getCopyRegion().empty())
233 initializedVal = createAlwaysInlineFuncAndCallIt(
234 privatizer.getCopyRegion(),
235 llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
236 {moldArg, initializedVal}, /*returnsValue=*/true);
237
// For by-value privatization the resulting value is written back into the
// heap memory.
238 if (isPrivatizedByValue)
239 (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
240
241 // clone origOp, replace all uses of varPtr with heapMem and
242 // erase origOp.
243 auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
244 Operation *clonedOp = rewriter.clone(*origOp);
245 rewriter.replaceAllOpUsesWith(origOp, clonedOp);
246 rewriter.modifyOpInPlace(clonedOp, [&]() {
247 clonedOp->replaceUsesOfWith(varPtr, heapMem);
248 });
249 rewriter.eraseOp(origOp);
250 return clonedOp;
251 };
252
253 // Now that we have set up the heap-allocated copy of the private
254 // variable, rewrite all the uses of the original variable with
255 // the heap-allocated variable.
256 rewriter.setInsertionPoint(targetOp);
257 mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
258 rewriter.setInsertionPoint(mapInfoOp);
259
260 // Fix any members that may use varPtr to now use heapMem
261 for (auto member : mapInfoOp.getMembers()) {
262 auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
263 if (!usesVarPtr(memberMapInfoOp))
264 continue;
265 memberMapInfoOp =
266 cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
267 rewriter.setInsertionPoint(memberMapInfoOp);
268
// NOTE(review): unlike the collection loop above, the defining op of
// varPtrPtr is neither null-checked nor tested with usesVarPtr before it is
// cloned and erased.
269 if (memberMapInfoOp.getVarPtrPtr()) {
270 Operation *varPtrPtrdefOp =
271 memberMapInfoOp.getVarPtrPtr().getDefiningOp();
272 rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
273 }
274 }
275
276 // If the type of the private variable is not a pointer,
277 // which is typically the case with !fir.boxchar types, then
278 // we need to ensure that the new private variable is also
279 // not a pointer. Insert a load from heapMem right before
280 // targetOp.
281 if (isPrivatizedByValue) {
282 rewriter.setInsertionPoint(targetOp);
283 auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
284 varType, heapMem);
285 newPrivVars.push_back(newPrivVar);
286 }
287
288 // Deallocate
// NOTE(review): the call to free below sits inside this `if`, so heap memory
// of a privatizer without a dealloc region is not freed by any code visible
// here.
289 if (needsCleanupTask) {
290 if (!cleanupTaskOp) {
291 assert(fakeDependVar &&
292 "Need a valid value to set up a dependency");
293 rewriter.setInsertionPointAfter(targetOp);
294 omp::TaskOperands taskOperands;
295 auto inDepend = omp::ClauseTaskDependAttr::get(
296 rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
297 taskOperands.dependKinds.push_back(inDepend);
298 taskOperands.dependVars.push_back(fakeDependVar);
299 cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
300 Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
301 rewriter.setInsertionPointToEnd(taskBlock);
302 omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
303 }
// NOTE(review): original line 304 is missing from this listing; the next line
// is only the tail of a call that takes the first block of the cleanup task's
// region. Presumably it positions the rewriter inside that block - confirm
// against the upstream file.
305 &*cleanupTaskOp.getRegion().getBlocks().begin());
306 (void)createAlwaysInlineFuncAndCallIt(
307 privatizer.getDeallocRegion(),
308 llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
309 .str(),
310 {initializedVal}, /*returnsValue=*/false);
311 llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
312 LLVM::lookupOrCreateFreeFn(rewriter, mod);
313 assert(llvm::succeeded(freeFunc) &&
314 "Could not find free in the module");
315 (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
316 ValueRange{heapMem});
317 }
318 }
319 assert(newPrivVars.size() == privateVars.size() &&
320 "The number of private variables must match before and after "
321 "transformation");
// Append depend(out: fakeDependVar) to the target, keeping any depend kinds
// it already had.
322 if (fakeDependVar) {
323 omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
324 rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
325 SmallVector<Attribute> newDependKinds;
326 if (!targetOp.getDependVars().empty()) {
327 std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
328 assert(dependKinds && "bad depend clause in omp::TargetOp");
329 llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
330 }
331 newDependKinds.push_back(outDepend);
332 ArrayAttr newDependKindsAttr =
333 ArrayAttr::get(rewriter.getContext(), newDependKinds);
334 targetOp.getDependVarsMutable().append(fakeDependVar);
335 targetOp.setDependKindsAttr(newDependKindsAttr);
336 }
// Swap the target's private-var operands for the rewritten list.
337 rewriter.setInsertionPoint(targetOp);
338 targetOp.getPrivateVarsMutable().clear();
339 targetOp.getPrivateVarsMutable().assign(newPrivVars);
340 });
341 }
342
343private:
344 bool hasPrivateVars(omp::TargetOp targetOp) const {
345 return !targetOp.getPrivateVars().empty();
346 }
347
348 bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
349 return targetOp.getNowait();
350 }
351
352 template <typename OpTy>
353 omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
354 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
355 omp::PrivateClauseOp privatizer =
357 op, privatizerName);
358 return privatizer;
359 }
360
361 // Get the (compile-time constant) size of varType as per the
362 // given DataLayout dl.
363 std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
364 llvm::TypeSize size = dl.getTypeSize(varType);
365 unsigned short alignment = dl.getTypeABIAlignment(varType);
366 return llvm::alignTo(size, alignment);
367 }
368
369 LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
370 llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
371 LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
372 assert(llvm::succeeded(mallocCall) &&
373 "Could not find malloc in the module");
374 return mallocCall.value();
375 }
376
377 Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
378 ModuleOp mod, IRRewriter &rewriter) const {
379 OpBuilder::InsertionGuard guard(rewriter);
380 Value varPtr = privVar;
381 Operation *definingOp = varPtr.getDefiningOp();
382 BlockArgument blockArg;
383 if (!definingOp) {
384 blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
385 rewriter.setInsertionPointToStart(blockArg.getParentBlock());
386 } else {
387 rewriter.setInsertionPoint(definingOp);
388 }
389 Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
390 LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
391
392 assert(mod.getDataLayoutSpec() &&
393 "MLIR module with no datalayout spec not handled yet");
394
395 const DataLayout &dl = DataLayout(mod);
396 std::int64_t distance = getSizeInBytes(dl, varType);
397
398 Value sizeBytes = LLVM::ConstantOp::create(
399 rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
400
401 auto mallocCallOp =
402 LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
403 return mallocCallOp.getResult();
404 }
405
406 // Create a function for srcRegion and attribute it to be always_inline.
407 // The big assumption here is that srcRegion is one of init, copy or dealloc
408 // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
409 // to either be the same as the types of the two arguments of the region (for
410 // init and copy regions) or void as would be the case for dealloc regions.
411 LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
412 Region &srcRegion,
413 llvm::StringRef funcName,
414 IRRewriter &rewriter,
415 bool returnsValue = false) {
416
417 OpBuilder::InsertionGuard guard(rewriter);
418 rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
419 Region clonedRegion;
420 IRMapping mapper;
421 srcRegion.cloneInto(&clonedRegion, mapper);
422
423 SmallVector<Type> paramTypes;
424 llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
425 Type resultType = returnsValue
426 ? srcRegion.getArgument(0).getType()
427 : LLVM::LLVMVoidType::get(rewriter.getContext());
428 LLVM::LLVMFunctionType funcType =
429 LLVM::LLVMFunctionType::get(resultType, paramTypes);
430
431 LLVM::LLVMFuncOp func =
432 LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
433 func.setAlwaysInline(true);
434 rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
435 func.getRegion().end());
436 for (auto &block : func.getRegion().getBlocks()) {
437 if (isa<omp::YieldOp>(block.getTerminator())) {
438 omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
439 rewriter.setInsertionPoint(yieldOp);
440 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
441 yieldOp.getOperands());
442 }
443 }
444 return func;
445 }
446};
447} // namespace
ArrayAttr()
b getContext())
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
IntegerType getI64Type()
Definition Builders.cpp:65
MLIRContext * getContext() const
Definition Builders.h:56
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
user_iterator user_end()
Definition Operation.h:870
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
user_iterator user_begin()
Definition Operation.h:869
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
bool empty()
Definition Region.h:60
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition Region.cpp:70
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
Definition Region.cpp:36
BlockArgument getArgument(unsigned i)
Definition Region.h:124
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128