OptimizeSharedMemory.cpp
//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize)) * vectorSize
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
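/// Illustrative example (an assumed shape, not part of the contract): for a
/// `memref<?x32xf16>` a row spans 64 bytes, so `permuteEveryN = 2` and a
/// vector holds 8 elements. Rows 0-1 keep their vector order, rows 2-3 xor
/// the vector index with 1, rows 4-5 with 2, rows 6-7 with 3, and the
/// pattern repeats every 8 rows.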
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
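  // E.g. (illustrative): for f16 elements and a target dim of size 32, a row
  // spans 32 * 16 / 8 = 64 bytes, so permuteEveryN = 128 / 64 = 2 and the
  // xor pattern advances only every other row.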

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
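  // E.g. (illustrative): for f16 elements, n = log2(128 / 16) = 3; with
  // dimSize(tgtDim) = 32, m = log2(32) = 5, so bits [3:5) select one of the
  // four 128-bit vectors within a row.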

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }
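  // At this point `srcBits` is aligned so that its set bits fall within
  // bits [n:m) of the index, i.e. the xor below only permutes the vector
  // index and leaves the sub-vector element offset untouched.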

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

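/// Replace the index for `tgtDim` with the permuted value computed by
/// `permuteVectorOffset`, using the index of `srcDim` as the permutation
/// source.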
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

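/// Return the indexing operands of a supported shared memory access op.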
static Operation::operand_range getIndices(Operation *op) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndices();
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndices();
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndices();
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndices();
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndices();
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndices();
  llvm_unreachable("unsupported op type");
}

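/// Replace the indexing operands of a supported shared memory access op.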
static void setIndices(Operation *op, ArrayRef<Value> indices) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndicesMutable().assign(indices);
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndicesMutable().assign(indices);
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndicesMutable().assign(indices);
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndicesMutable().assign(indices);
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndicesMutable().assign(indices);
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndicesMutable().assign(indices);
  llvm_unreachable("unsupported op type");
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

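/// Entry point of the transform: rewrite all supported reads/writes of
/// `memrefValue` nested under `parentOp` with xor-swizzled indices. A sketch
/// of the effect on the IR (illustrative, assuming a `memref<64x16xf16, 3>`
/// in shared memory; the constants depend on the shape and element type):
///
///   %v = memref.load %shm[%i, %j] : memref<64x16xf16, 3>
///
/// becomes
///
///   %c4 = arith.constant 4 : index
///   %0  = arith.andi %i, %c4 : index
///   %c1 = arith.constant 1 : index
///   %1  = arith.shli %0, %c1 : index
///   %j2 = arith.xori %j, %1 : index
///   %v  = memref.load %shm[%i, %j2] : memref<64x16xf16, 3>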
mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check whether the transform is needed given the assumption of 128-bit
  // accesses: skip if dim[rank-1] is small enough that at least
  // `threadGroupSize` rows fit in a single 128-byte line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
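  // With the defaults this is 1 << (7 - log2(128 / 8)) = 8: eight threads
  // each making a 128-bit (16-byte) access cover one 128-byte line.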
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
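
// Typical usage (assuming the pass is registered under the
// `nvgpu-optimize-shared-memory` flag, as in the upstream Passes.td):
//
//   mlir-opt --nvgpu-optimize-shared-memory input.mlir
//
// or, programmatically:
//
//   pm.addPass(nvgpu::createOptimizeSharedMemoryPass());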