23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/MathExtras.h"
28 #define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
29 #include "mlir/Dialect/NVGPU/Passes.h.inc"
/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
///   `result = xor(floordiv(srcIdxVal, permuteEveryN),
///                 floordiv(tgtIdxVal, vectorSize)) + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128 / elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N) = sub-vector element offset
  //   bits[N:M) = vector index
  int64_t n = llvm::Log2_64(kDefaultVectorSizeBits /
                            memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits [0:(M-N)) of src by first creating an (M-N)-bit mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
  // Use the src bits to permute the target bits b[N:M) containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}
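// Worked example (an illustrative sketch, not part of the original file):
// for a hypothetical `memref<32x32xf16>` in shared memory, the quantities
// above come out to:
//   permuteEveryN = max(1, 128 / ((32 * 16) / 8)) = 2
//   n = log2(128 / 16) = 3, m = log2(32) = 5
//   mask = ((1 << (5 - 3)) - 1) << log2(2) = 0b110
//   shlBits = 3 - 1 = 2
// An access at (row, col) = (5, 8) is then permuted as:
//   srcBits = 5 & 0b110 = 0b100
//   srcBits << 2        = 0b10000 (16)
//   col'    = 8 xor 16  = 24
// i.e. the 128-bit vector index of the column is XORed with
// floordiv(row, permuteEveryN), spreading consecutive rows across
// different shared memory banks.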
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}
Operation::operand_range getIndices(Operation *op) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndices();
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndices();
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndices();
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndices();
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndices();
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndices();
  llvm_unreachable("unsupported op type");
}
void setIndices(Operation *op, ArrayRef<Value> indices) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndicesMutable().assign(indices);
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndicesMutable().assign(indices);
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndicesMutable().assign(indices);
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndicesMutable().assign(indices);
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndicesMutable().assign(indices);
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndicesMutable().assign(indices);
  llvm_unreachable("unsupported op type");
}
/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}
mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();
  // Check if this is necessary given the assumption of 128b accesses:
  // If dim[rank-1] is small enough to fit 8 rows in a 128B line, it is not
  // subject to bank conflicts.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();
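  // Worked numbers for the check above (an illustrative sketch under the
  // constants defined at the top of this file): for f16 elements, one
  // 128-byte line holds 8 * 128 / 16 = 64 elements, and threadGroupSize =
  // 1 << (7 - log2(128 / 8)) = 8. A row of 32 f16 values gives rowsPerLine
  // = 64 / 32 = 2 < 8, so the swizzle is applied; rows of 8 or fewer f16
  // values already fit 8+ rows per line and are skipped.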
  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();
  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());
  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;
  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }
  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}
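// After the rewrite, a read such as `memref.load %shm[%row, %col]` becomes,
// roughly (an illustrative sketch; the exact constants depend on the memref
// shape and element type, and `createOrFold` may fold some operations away):
//   %mask    = arith.constant 6 : index
//   %srcBits = arith.andi %row, %mask : index
//   %shl     = arith.shli %srcBits, %c2 : index
//   %colSwz  = arith.xori %col, %shl : index
//   ...      = memref.load %shm[%row, %colSwz]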
namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
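// Usage sketch (illustrative, assuming a standard MLIR PassManager setup and
// a module whose functions contain the shared memory allocations):
//
//   PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(nvgpu::createOptimizeSharedMemoryPass());
//   if (failed(pm.run(module)))
//     signalPassFailure();
//
// When registered, the pass should also be reachable from mlir-opt via its
// pass argument, typically `-nvgpu-optimize-shared-memory`.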