18 #include "llvm/Support/DebugLog.h"
19 #include "llvm/Support/FormatVariadic.h"
30 #define DEBUG_TYPE "omp-prepare-for-offload-privatization"
35 #define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
36 #include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
48 class PrepareForOMPOffloadPrivatizationPass
49 :
public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
50 PrepareForOMPOffloadPrivatizationPass> {
52 void runOnOperation()
override {
53 ModuleOp mod = getOperation();
59 auto offloadModuleInterface =
60 dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
61 if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
64 getOperation()->walk([&](omp::TargetOp targetOp) {
65 if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
71 omp::TaskOp cleanupTaskOp;
73 newPrivVars.reserve(privateVars.size());
74 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
75 for (
auto [privVarIdx, privVarSymPair] :
77 Value privVar = std::get<0>(privVarSymPair);
78 Attribute privSym = std::get<1>(privVarSymPair);
80 omp::PrivateClauseOp privatizer =
findPrivatizer(targetOp, privSym);
81 if (!privatizer.needsMap()) {
82 newPrivVars.push_back(privVar);
85 bool isFirstPrivate = privatizer.getDataSharingType() ==
86 omp::DataSharingClauseType::FirstPrivate;
88 Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
89 auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.
getDefiningOp());
91 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
92 newPrivVars.push_back(privVar);
123 bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
129 Value varPtr = mapInfoOp.getVarPtr();
130 Type varType = mapInfoOp.getVarType();
131 bool isPrivatizedByValue =
132 !isa<LLVM::LLVMPointerType>(privVar.
getType());
134 assert(isa<LLVM::LLVMPointerType>(varPtr.
getType()));
136 allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
139 "Unable to allocate heap memory when trying to move "
140 "a private variable out of the stack and into the "
141 "heap for use by a deferred target task");
143 if (needsCleanupTask && !fakeDependVar)
144 fakeDependVar = heapMem;
154 if (!isPrivatizedByValue)
155 newPrivVars.push_back(heapMem);
165 if (varPtrDefiningOp) {
169 auto blockArg = cast<BlockArgument>(varPtr);
170 users.insert(blockArg.user_begin(), blockArg.user_end());
172 auto usesVarPtr = [&users](
Operation *op) ->
bool {
173 return users.count(op);
177 chainOfOps.push_back(mapInfoOp);
178 for (
auto member : mapInfoOp.getMembers()) {
179 omp::MapInfoOp memberMap =
180 cast<omp::MapInfoOp>(member.getDefiningOp());
181 if (usesVarPtr(memberMap))
182 chainOfOps.push_back(memberMap);
183 if (memberMap.getVarPtrPtr()) {
184 Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
185 if (defOp && usesVarPtr(defOp))
186 chainOfOps.push_back(defOp);
204 auto createAlwaysInlineFuncAndCallIt =
205 [&](
Region &region, llvm::StringRef funcName,
207 assert(!region.
empty() &&
"region cannot be empty");
208 LLVM::LLVMFuncOp func = createFuncOpForRegion(
209 loc, mod, region, funcName, rewriter, returnsValue);
210 auto call = LLVM::CallOp::create(rewriter, loc, func, args);
211 return call.getResult();
214 Value moldArg, newArg;
215 if (isPrivatizedByValue) {
216 moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
217 newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
223 Value initializedVal;
224 if (!privatizer.getInitRegion().empty())
225 initializedVal = createAlwaysInlineFuncAndCallIt(
226 privatizer.getInitRegion(),
227 llvm::formatv(
"{0}_{1}", privatizer.getSymName(),
"init").str(),
228 {moldArg, newArg},
true);
230 initializedVal = newArg;
232 if (isFirstPrivate && !privatizer.getCopyRegion().empty())
233 initializedVal = createAlwaysInlineFuncAndCallIt(
234 privatizer.getCopyRegion(),
235 llvm::formatv(
"{0}_{1}", privatizer.getSymName(),
"copy").str(),
236 {moldArg, initializedVal},
true);
238 if (isPrivatizedByValue)
239 (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
257 mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
261 for (
auto member : mapInfoOp.getMembers()) {
262 auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
263 if (!usesVarPtr(memberMapInfoOp))
266 cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
269 if (memberMapInfoOp.getVarPtrPtr()) {
271 memberMapInfoOp.getVarPtrPtr().getDefiningOp();
281 if (isPrivatizedByValue) {
283 auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
285 newPrivVars.push_back(newPrivVar);
289 if (needsCleanupTask) {
290 if (!cleanupTaskOp) {
291 assert(fakeDependVar &&
292 "Need a valid value to set up a dependency");
294 omp::TaskOperands taskOperands;
296 rewriter.
getContext(), omp::ClauseTaskDepend::taskdependin);
297 taskOperands.dependKinds.push_back(inDepend);
298 taskOperands.dependVars.push_back(fakeDependVar);
299 cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
302 omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
305 &*cleanupTaskOp.getRegion().getBlocks().begin());
306 (void)createAlwaysInlineFuncAndCallIt(
307 privatizer.getDeallocRegion(),
308 llvm::formatv(
"{0}_{1}", privatizer.getSymName(),
"dealloc")
310 {initializedVal},
false);
311 llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
313 assert(llvm::succeeded(freeFunc) &&
314 "Could not find free in the module");
315 (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
319 assert(newPrivVars.size() == privateVars.size() &&
320 "The number of private variables must match before and after "
324 rewriter.
getContext(), omp::ClauseTaskDepend::taskdependout);
326 if (!targetOp.getDependVars().empty()) {
327 std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
328 assert(dependKinds &&
"bad depend clause in omp::TargetOp");
329 llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
331 newDependKinds.push_back(outDepend);
332 ArrayAttr newDependKindsAttr =
334 targetOp.getDependVarsMutable().append(fakeDependVar);
335 targetOp.setDependKindsAttr(newDependKindsAttr);
338 targetOp.getPrivateVarsMutable().clear();
339 targetOp.getPrivateVarsMutable().assign(newPrivVars);
344 bool hasPrivateVars(omp::TargetOp targetOp)
const {
345 return !targetOp.getPrivateVars().empty();
348 bool isTargetTaskDeferred(omp::TargetOp targetOp)
const {
349 return targetOp.getNowait();
352 template <
typename OpTy>
354 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
355 omp::PrivateClauseOp privatizer =
356 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
366 return llvm::alignTo(size, alignment);
369 LLVM::LLVMFuncOp getMalloc(ModuleOp mod,
IRRewriter &rewriter)
const {
370 llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
372 assert(llvm::succeeded(mallocCall) &&
373 "Could not find malloc in the module");
374 return mallocCall.value();
377 Value allocateHeapMem(omp::TargetOp targetOp,
Value privVar,
Type varType,
380 Value varPtr = privVar;
384 blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
390 LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
392 assert(mod.getDataLayoutSpec() &&
393 "MLIR module with no datalayout spec not handled yet");
398 Value sizeBytes = LLVM::ConstantOp::create(
399 rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
402 LLVM::CallOp::create(rewriter, loc, mallocFn,
ValueRange{sizeBytes});
403 return mallocCallOp.getResult();
411 LLVM::LLVMFuncOp createFuncOpForRegion(
Location loc, ModuleOp mod,
413 llvm::StringRef funcName,
415 bool returnsValue =
false) {
421 srcRegion.
cloneInto(&clonedRegion, mapper);
425 Type resultType = returnsValue
428 LLVM::LLVMFunctionType funcType =
431 LLVM::LLVMFuncOp func =
432 LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
433 func.setAlwaysInline(
true);
435 func.getRegion().end());
436 for (
auto &block : func.getRegion().getBlocks()) {
437 if (isa<omp::YieldOp>(block.getTerminator())) {
438 omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
441 yieldOp.getOperands());
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static omp::PrivateClauseOp findPrivatizer(Operation *from, SymbolRefAttr symbolName)
Looks up from the operation from and returns the PrivateClauseOp with name symbolName.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
MLIRContext * getContext() const
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Location getLoc()
The source location the operation was defined or derived from.
user_iterator user_begin()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
BlockArgument getArgument(unsigned i)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...