22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/MathExtras.h"
27#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
28#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
55 const int64_t permuteEveryN = std::max<int64_t>(
57 memrefTy.getElementTypeBitWidth()) /
71 int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
74 int64_t mask = (1LL << (m - n)) - 1;
75 if (permuteEveryN > 1)
76 mask = mask << llvm::Log2_64(permuteEveryN);
78 srcBits = arith::AndIOp::create(
b, loc, src, srcBits);
82 if (permuteEveryN > 1) {
83 int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
86 srcBits =
b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
87 }
else if (shlBits < 0) {
90 srcBits =
b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
94 srcBits =
b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
97 Value permutedVectorIdx =
98 arith::XOrIOp::create(
b, loc,
indices[tgtDim], srcBits);
99 return permutedVectorIdx;
104 MemRefType memrefTy,
int64_t srcDim,
117 MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
120 std::optional<MemoryEffects::EffectInstance> effect =
123 readOps.push_back(op);
128 writeOps.push_back(op);
133 if (llvm::any_of(readOps, [](
Operation *op) {
134 return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
138 if (llvm::any_of(writeOps, [](
Operation *op) {
139 return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
151 auto memRefType = dyn_cast<MemRefType>(memrefValue.
getType());
152 if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
156 if (memRefType.getRank() == 0)
161 if (!memRefType.getElementType().isIntOrFloat())
166 bool hasSubView =
false;
167 parentOp->
walk([&](memref::SubViewOp subView) { hasSubView =
true; });
173 const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
174 if (ShapedType::isDynamic(rowSize) || rowSize == 0)
180 const int64_t threadGroupSize =
182 if (rowsPerLine >= threadGroupSize)
193 if (shmReadOps.empty() || shmWriteOps.empty())
198 int64_t tgtDim = memRefType.getRank() - 1;
199 int64_t srcDim = memRefType.getRank() - 2;
202 while (!shmWriteOps.empty()) {
203 Operation *shmWriteOp = shmWriteOps.back();
204 shmWriteOps.pop_back();
210 memRefType, srcDim, tgtDim);
215 while (!shmReadOps.empty()) {
216 Operation *shmReadOp = shmReadOps.back();
217 shmReadOps.pop_back();
223 memRefType, srcDim, tgtDim);
231class OptimizeSharedMemoryPass
232 :
public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
234 OptimizeSharedMemoryPass() =
default;
236 void runOnOperation()
override {
239 op->
walk([&](memref::AllocOp allocOp) {
240 if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
242 shmAllocOps.push_back(allocOp);
244 for (
auto allocOp : shmAllocOps) {
246 allocOp.getMemref())))
254 return std::make_unique<OptimizeSharedMemoryPass>();
constexpr int64_t kSharedMemoryLineSizeBytes
The size of a shared memory line according to NV documentation.
static void transformIndices(OpBuilder &builder, Location loc, SmallVector< Value, 4 > &indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
constexpr int64_t kDefaultVectorSizeBits
We optimize for 128-bit accesses, but this can be made an argument in the future.
static Value permuteVectorOffset(OpBuilder &b, Location loc, ArrayRef< Value > indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
Uses srcIndexValue to permute tgtIndexValue via result = xor(floordiv(srcIdxVal, permuteEveryN), floordiv(tgtIdxVal, vectorSize)) + tgtIdxVal % vectorSize.
static LogicalResult getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef, SmallVector< Operation *, 16 > &readOps, SmallVector< Operation *, 16 > &writeOps)
Return all operations within parentOp that read from or write to shmMemRef.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a pointer to a storage object owned by an MLIRContext.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
llvm::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue)
Passes.
std::unique_ptr< Pass > createOptimizeSharedMemoryPass()
Create a pass to optimize shared memory reads and writes.
void setIndices(Operation *op, ArrayRef< Value > indices)
Set the indices that the given load/store operation is operating on.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
The following effect indicates that the operation reads from some resource.
The following effect indicates that the operation writes to some resource.