// Source: mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
// (LLVM/MLIR project; extracted from Doxygen "MLIR 22.0.0git" documentation)
1//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements transforms to optimize accesses to shared memory.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
25namespace mlir {
26namespace nvgpu {
27#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
28#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
29} // namespace nvgpu
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::nvgpu;
34
/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;
41/// Uses `srcIndexValue` to permute `tgtIndexValue` via
42/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
43/// floordiv(tgtIdxVal,vectorSize)))
44/// + tgtIdxVal % vectorSize`
45/// This is done using an optimized sequence of `arith` operations.
47 ArrayRef<Value> indices, MemRefType memrefTy,
48 int64_t srcDim, int64_t tgtDim) {
49 // Adjust the src index to change how often the permutation changes
50 // if necessary.
51 Value src = indices[srcDim];
52
53 // We only want to permute every N iterations of the target dim where N is
54 // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
55 const int64_t permuteEveryN = std::max<int64_t>(
56 1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
57 memrefTy.getElementTypeBitWidth()) /
58 8));
59
60 // clang-format off
61 // Index bit representation (b0 = least significant bit) for dim(1)
62 // of a `memref<?x?xDT>` is as follows:
63 // N := log2(128/elementSizeBits)
64 // M := log2(dimSize(1))
65 // then
66 // bits[0:N] = sub-vector element offset
67 // bits[N:M] = vector index
68 // clang-format on
69 int64_t n =
70 llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
71 int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
72
73 // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
74 int64_t mask = (1LL << (m - n)) - 1;
75 if (permuteEveryN > 1)
76 mask = mask << llvm::Log2_64(permuteEveryN);
77 Value srcBits = arith::ConstantIndexOp::create(b, loc, mask);
78 srcBits = arith::AndIOp::create(b, loc, src, srcBits);
79
80 // Use the src bits to permute the target bits b[N:M] containing the
81 // vector offset.
82 if (permuteEveryN > 1) {
83 int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
84 if (shlBits > 0) {
85 Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, shlBits);
86 srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
87 } else if (shlBits < 0) {
88 Value finalShiftVal =
89 arith::ConstantIndexOp::create(b, loc, -1 * shlBits);
90 srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
91 }
92 } else {
93 Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, n);
94 srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
95 }
96
97 Value permutedVectorIdx =
98 arith::XOrIOp::create(b, loc, indices[tgtDim], srcBits);
99 return permutedVectorIdx;
100}
101
102static void transformIndices(OpBuilder &builder, Location loc,
104 MemRefType memrefTy, int64_t srcDim,
105 int64_t tgtDim) {
106 indices[tgtDim] =
107 permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
108}
109
110/// Return all operations within `parentOp` that read from or write to
111/// `shmMemRef`.
112static LogicalResult
116 parentOp->walk([&](Operation *op) {
117 MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
118 if (!iface)
119 return;
120 std::optional<MemoryEffects::EffectInstance> effect =
121 iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
122 if (effect) {
123 readOps.push_back(op);
124 return;
125 }
126 effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
127 if (effect)
128 writeOps.push_back(op);
129 });
130
131 // Restrict to a supported set of ops. We also require at least 2D access,
132 // although this could be relaxed.
133 if (llvm::any_of(readOps, [](Operation *op) {
134 return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
135 getIndices(op).size() < 2;
136 }))
137 return failure();
138 if (llvm::any_of(writeOps, [](Operation *op) {
139 return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
140 op) ||
141 getIndices(op).size() < 2;
142 }))
143 return failure();
144
145 return success();
146}
147
148llvm::LogicalResult
150 Value memrefValue) {
151 auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
152 if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
153 return failure();
154
155 // Not support 0D MemRefs.
156 if (memRefType.getRank() == 0)
157 return failure();
158
159 // Abort if the given value has any sub-views; we do not do any alias
160 // analysis.
161 bool hasSubView = false;
162 parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
163 if (hasSubView)
164 return failure();
165
166 // Check if this is necessary given the assumption of 128b accesses:
167 // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
168 const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
169 const int64_t rowsPerLine =
170 (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
171 rowSize;
172 const int64_t threadGroupSize =
173 1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
174 if (rowsPerLine >= threadGroupSize)
175 return failure();
176
177 // Get sets of operations within the function that read/write to shared
178 // memory.
181 if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
182 shmWriteOps)))
183 return failure();
184
185 if (shmReadOps.empty() || shmWriteOps.empty())
186 return failure();
187
188 OpBuilder builder(parentOp->getContext());
189
190 int64_t tgtDim = memRefType.getRank() - 1;
191 int64_t srcDim = memRefType.getRank() - 2;
192
193 // Transform indices for the ops writing to shared memory.
194 while (!shmWriteOps.empty()) {
195 Operation *shmWriteOp = shmWriteOps.back();
196 shmWriteOps.pop_back();
197 builder.setInsertionPoint(shmWriteOp);
198
199 auto indices = getIndices(shmWriteOp);
200 SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
201 transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
202 memRefType, srcDim, tgtDim);
203 setIndices(shmWriteOp, transformedIndices);
204 }
205
206 // Transform indices for the ops reading from shared memory.
207 while (!shmReadOps.empty()) {
208 Operation *shmReadOp = shmReadOps.back();
209 shmReadOps.pop_back();
210 builder.setInsertionPoint(shmReadOp);
211
212 auto indices = getIndices(shmReadOp);
213 SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
214 transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
215 memRefType, srcDim, tgtDim);
216 setIndices(shmReadOp, transformedIndices);
217 }
218
219 return success();
220}
221
222namespace {
223class OptimizeSharedMemoryPass
224 : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
225public:
226 OptimizeSharedMemoryPass() = default;
227
228 void runOnOperation() override {
229 Operation *op = getOperation();
231 op->walk([&](memref::AllocOp allocOp) {
232 if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
233 return;
234 shmAllocOps.push_back(allocOp);
235 });
236 for (auto allocOp : shmAllocOps) {
237 if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
238 allocOp.getMemref())))
239 return;
240 }
241 }
242};
243} // namespace
244
246 return std::make_unique<OptimizeSharedMemoryPass>();
247}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
constexpr int64_t kSharedMemoryLineSizeBytes
The size of a shared memory line according to NV documentation.
static void transformIndices(OpBuilder &builder, Location loc, SmallVector< Value, 4 > &indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
constexpr int64_t kDefaultVectorSizeBits
We optimize for 128bit accesses, but this can be made an argument in the future.
static Value permuteVectorOffset(OpBuilder &b, Location loc, ArrayRef< Value > indices, MemRefType memrefTy, int64_t srcDim, int64_t tgtDim)
Uses srcIndexValue to permute tgtIndexValue via result = xor(floordiv(srcIdxVal,permuteEveryN),...
static LogicalResult getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef, SmallVector< Operation *, 16 > &readOps, SmallVector< Operation *, 16 > &writeOps)
Return all operations within parentOp that read from or write to shmMemRef.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
llvm::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue)
Passes.
std::unique_ptr< Pass > createOptimizeSharedMemoryPass()
Create a pass to optimize shared memory reads and writes.
void setIndices(Operation *op, ArrayRef< Value > indices)
Set the indices that the given load/store operation is operating on.
Definition Utils.cpp:38
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
Include the generated interface declarations.
The following effect indicates that the operation reads from some resource.
The following effect indicates that the operation writes to some resource.