MLIR 23.0.0git
StackToShared.cpp
Go to the documentation of this file.
1//===- StackToShared.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements transforms to swap stack allocations on the target
10// device with device shared memory where applicable.
11//
12//===----------------------------------------------------------------------===//
13
15
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/STLExtras.h"
21
22namespace mlir {
23namespace omp {
24#define GEN_PASS_DEF_STACKTOSHAREDPASS
25#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
26} // namespace omp
27} // namespace mlir
28
29using namespace mlir;
30
31/// Tell whether to replace an operation representing a stack allocation with a
32/// device shared memory allocation/deallocation pair based on the location of
33/// the allocation and its uses.
///
/// The visible `llvm::any_of` term below is true when at least one result of
/// `op` has uses for which `omp::allocaUsesRequireSharedMem` returns true.
/// NOTE(review): source lines 34-35 (the signature and the start of the
/// returned expression) are missing from this listing, so any further
/// condition combined with that term cannot be confirmed here.
36 llvm::any_of(op.getResults(), [&](Value result) {
37 return omp::allocaUsesRequireSharedMem(result);
38 });
39}
40
41/// Based on the location of the definition of the given value representing the
42/// result of a device shared memory allocation, find the corresponding points
43/// where its deallocation should be placed and introduce `omp.free_shared_mem`
44/// ops at those points.
///
/// `elemType`, `arraySize` and `alignment` are forwarded unchanged to every
/// `omp.free_shared_mem` created here. Each one is built with `builder` at
/// `allocVal`'s location, so `builder`'s insertion point is modified.
/// NOTE(review): source line 45 (the start of the signature) is missing from
/// this listing.
46 TypeAttr elemType,
47 Value arraySize,
48 IntegerAttr alignment,
49 Value allocVal) {
50 Block *allocaBlock = allocVal.getParentBlock();
// NOTE(review): a fresh DominanceInfo is constructed on every call; if this is
// invoked once per allocation, consider computing it once per function.
51 DominanceInfo domInfo;
// Only the blocks of the region that defines `allocVal` are inspected.
52 for (Block &block : allocVal.getParentRegion()->getBlocks()) {
53 Operation *terminator = block.getTerminator();
// A terminator without successors does not branch to another block of this
// region. If the block holding the allocation dominates this block, the
// allocation has executed on every path reaching the terminator, so the
// buffer is freed right before it.
54 if (!terminator->hasSuccessors() &&
55 domInfo.dominates(allocaBlock, &block)) {
56 builder.setInsertionPoint(terminator);
57 omp::FreeSharedMemOp::create(builder, allocVal.getLoc(), elemType,
58 arraySize, alignment, allocVal);
59 }
60 }
61}
62
63namespace {
/// Pass over an `llvm.func` that, when the enclosing module is an offload
/// module compiled for the target device, replaces selected `llvm.alloca` ops
/// with `omp.alloc_shared_mem` and adds matching `omp.free_shared_mem` ops
/// before the region is exited.
64class StackToSharedPass
65 : public omp::impl::StackToSharedPassBase<StackToSharedPass> {
66public:
67 StackToSharedPass() = default;
68
69 void runOnOperation() override {
70 MLIRContext *context = &getContext();
71 OpBuilder builder(context);
72
73 LLVM::LLVMFuncOp funcOp = getOperation();
// Nothing to do unless the function lives in an offload module that is being
// compiled for the target device.
74 auto offloadIface = funcOp->getParentOfType<omp::OffloadModuleInterface>();
75 if (!offloadIface || !offloadIface.getIsTargetDevice())
76 return;
77
// Ops that must outlive the walk (see the address-space case below); they are
// erased after it, in insertion order.
78 llvm::SmallVector<Operation *> toBeDeleted;
79 funcOp->walk([&](LLVM::AllocaOp allocaOp) {
// NOTE(review): source line 80, the condition guarding this `return`, is
// missing from this listing; as shown, the `return` would skip every alloca.
81 return;
82 // Replace llvm.alloca with omp.alloc_shared_mem.
83 Type resultType = allocaOp.getResult().getType();
84
85 // TODO: The handling of non-default address spaces might need to be
86 // improved. This currently only handles the case where an alloca to
87 // non-default address space is only used by a single addrspacecast to
88 // default address space.
89 bool nonDefaultAddrSpace = false;
90 if (auto llvmPtrType = dyn_cast<LLVM::LLVMPointerType>(resultType))
91 nonDefaultAddrSpace = llvmPtrType.getAddressSpace() != 0;
92
// Build the replacement right before the alloca. It reuses the alloca's
// element type, array size and alignment, and its result type is
// `LLVM::LLVMPointerType::get(context)` (no explicit address space).
93 builder.setInsertionPoint(allocaOp);
94 auto sharedAllocOp = omp::AllocSharedMemOp::create(
95 builder, allocaOp->getLoc(), LLVM::LLVMPointerType::get(context),
96 allocaOp.getElemTypeAttr(), allocaOp.getArraySize(),
97 allocaOp.getAlignmentAttr());
98 if (nonDefaultAddrSpace) {
// NOTE(review): these preconditions are only enforced by `assert` and by
// `cast<>`, which itself only asserts. With assertions disabled, an alloca
// with several users, or whose single user is not an `llvm.addrspacecast`,
// is not diagnosed here. Consider skipping such allocas instead.
99 assert(allocaOp->hasOneUse() && " unsupported non-default address "
100 "space alloca with multiple uses");
101 auto asCastOp =
102 cast<LLVM::AddrSpaceCastOp>(*allocaOp->getUsers().begin());
103 asCastOp.replaceAllUsesWith(sharedAllocOp.getOperation());
104 // Defer erasure: the cast op must not be erased before the walk has
105 // visited it, and the alloca cannot be erased before the cast because
106 // the cast still uses it.
107 toBeDeleted.push_back(asCastOp);
108 toBeDeleted.push_back(allocaOp);
109 } else {
110 allocaOp.replaceAllUsesWith(sharedAllocOp.getOperation());
111 allocaOp.erase();
112 }
113
114 // Create a new omp.free_shared_mem for the allocated buffer prior to
115 // exiting the region.
// NOTE(review): source line 116 (the callee of the call whose arguments
// follow) is missing from this listing.
117 builder, sharedAllocOp.getMemElemTypeAttr(),
118 sharedAllocOp.getMemArraySize(), sharedAllocOp.getMemAlignmentAttr(),
119 sharedAllocOp.getResult());
120 });
121 for (Operation *op : toBeDeleted)
122 op->erase();
123 }
124};
125} // namespace
b getContext())
static void insertDeviceSharedMemDeallocation(OpBuilder &builder, TypeAttr elemType, Value arraySize, IntegerAttr alignment, Value allocVal)
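Based on the location of the definition of the given value representing the result of a device shared memory allocation, find the corresponding points where its deallocation should be placed and introduce `omp.free_shared_mem` ops at those points.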
static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op)
Tell whether to replace an operation representing a stack allocation with a device shared memory allocation/deallocation pair based on the location of the allocation and its uses.
Block represents an ordered list of Operations.
Definition Block.h:33
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
Definition Dominance.h:158
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasSuccessors()
Definition Operation.h:731
result_range getResults()
Definition Operation.h:441
BlockListType & getBlocks()
Definition Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
bool opInSharedDeviceContext(Operation &op)
Check whether the given operation is located in a context where an allocation to be used by multiple ...
Definition Utils.cpp:66
Include the generated interface declarations.