24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/MathExtras.h"
29 #define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
30 #include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
50 int64_t srcDim, int64_t tgtDim) {
53 Value src = indices[srcDim];
57 const int64_t permuteEveryN = std::max<int64_t>(
59 memrefTy.getElementTypeBitWidth()) /
73 int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
76 int64_t mask = (1LL << (m - n)) - 1;
77 if (permuteEveryN > 1)
78 mask = mask << llvm::Log2_64(permuteEveryN);
79 Value srcBits = b.
create<arith::ConstantIndexOp>(loc, mask);
80 srcBits = b.
create<arith::AndIOp>(loc, src, srcBits);
84 if (permuteEveryN > 1) {
85 int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
87 Value finalShiftVal = b.
create<arith::ConstantIndexOp>(loc, shlBits);
88 srcBits = b.
createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
89 }
else if (shlBits < 0) {
90 Value finalShiftVal = b.
create<arith::ConstantIndexOp>(loc, -1 * shlBits);
91 srcBits = b.
createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
94 Value finalShiftVal = b.
create<arith::ConstantIndexOp>(loc, n);
95 srcBits = b.
createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
98 Value permutedVectorIdx =
99 b.
create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
100 return permutedVectorIdx;
105 MemRefType memrefTy, int64_t srcDim,
118 MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
121 std::optional<MemoryEffects::EffectInstance> effect =
124 readOps.push_back(op);
129 writeOps.push_back(op);
134 if (llvm::any_of(readOps, [](
Operation *op) {
135 return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
139 if (llvm::any_of(writeOps, [](
Operation *op) {
140 return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
152 auto memRefType = dyn_cast<MemRefType>(memrefValue.
getType());
153 if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
158 bool hasSubView =
false;
159 parentOp->
walk([&](memref::SubViewOp subView) { hasSubView =
true; });
165 const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
166 const int64_t rowsPerLine =
169 const int64_t threadGroupSize =
171 if (rowsPerLine >= threadGroupSize)
182 if (shmReadOps.empty() || shmWriteOps.empty())
187 int64_t tgtDim = memRefType.getRank() - 1;
188 int64_t srcDim = memRefType.getRank() - 2;
191 while (!shmWriteOps.empty()) {
192 Operation *shmWriteOp = shmWriteOps.back();
193 shmWriteOps.pop_back();
199 memRefType, srcDim, tgtDim);
204 while (!shmReadOps.empty()) {
205 Operation *shmReadOp = shmReadOps.back();
206 shmReadOps.pop_back();
212 memRefType, srcDim, tgtDim);
220 class OptimizeSharedMemoryPass
221 :
public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
223 OptimizeSharedMemoryPass() =
default;
225 void runOnOperation()
override {
228 op->
walk([&](memref::AllocOp allocOp) {
229 if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
231 shmAllocOps.push_back(allocOp);
233 for (
auto allocOp : shmAllocOps) {
235 allocOp.getMemref())))
243 return std::make_unique<OptimizeSharedMemoryPass>();
constexpr int64_t kSharedMemoryLineSizeBytes
The size of a shared memory line according to NV documentation.
static void transformIndices(OpBuilder &builder, Location loc, SmallVector< Value, 4 > &indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
constexpr int64_t kDefaultVectorSizeBits
We optimize for 128bit accesses, but this can be made an argument in the future.
static Value permuteVectorOffset(OpBuilder &b, Location loc, ArrayRef< Value > indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
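Uses srcIndexValue to permute tgtIndexValue via `result = xor(floordiv(srcIdxVal,permuteEveryN), floordiv(tgtIdxVal,vectorSize)) + tgtIdxVal % vectorSize`. This is done using an optimized sequence of `arith` operations.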
static LogicalResult getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef, SmallVector< Operation *, 16 > &readOps, SmallVector< Operation *, 16 > &writeOps)
Return all operations within parentOp that read from or write to shmMemRef.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue)
Passes.
std::unique_ptr< Pass > createOptimizeSharedMemoryPass()
Create a pass to optimize shared memory reads and writes.
void setIndices(Operation *op, ArrayRef< Value > indices)
Set the indices that the given load/store operation is operating on.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
The following effect indicates that the operation reads from some resource.
The following effect indicates that the operation writes to some resource.