OptimizeSharedMemory.cpp
//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
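  // Worked example (hypothetical shapes, not fixed by this file): for a
  // memref<64x64xf16>, a row of the target dim spans 64 * 16 / 8 = 128 bytes,
  // so permuteEveryN = max(1, 128 / 128) = 1 and the permutation changes on
  // every src row; for a memref<64x32xf16>, a row spans 32 * 16 / 8 = 64
  // bytes and permuteEveryN = 2, so the pattern advances every second row.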

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  // N := log2(128/elementSizeBits)
  // M := log2(dimSize(1))
  // then
  // bits[0:N] = sub-vector element offset
  // bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
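  // E.g., for a hypothetical memref<64x64xf16>: n = log2(128 / 16) = 3 (each
  // 128-bit vector holds 8 f16 elements) and m = log2(64) = 6, so bits[3:6]
  // select one of the 8 vectors in a row.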

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}
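
// For intuition, in the simple permuteEveryN == 1 case the arith sequence
// built above reduces to the following plain-integer model (a sketch, not
// code used by the pass; `srcIdx` and `tgtIdx` stand for the two indices):
//
//   int64_t mask = (1LL << (m - n)) - 1;     // low (m - n) bits of srcIdx
//   int64_t srcBits = (srcIdx & mask) << n;  // align them with bits[n:m]
//   int64_t result = tgtIdx ^ srcBits;       // XOR-permute the vector index
//
// Rows with different srcIdx thus XOR their vector index with different
// values and fall on different shared memory banks.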

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if this is necessary given the assumption of 128b accesses:
  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();
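
  // With the defaults, threadGroupSize = 1 << (7 - log2(128 / 8)) = 8: eight
  // 16-byte (128-bit) accesses cover one 128-byte line, so when 8 or more
  // rows already fit in a single line the early exit above skips the swizzle.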

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}
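
// At the IR level the rewrite is purely local (a sketch, assuming a 2-D
// shared memory buffer): a write `memref.store %v, %shm[%r, %c]` becomes
// `memref.store %v, %shm[%r, %c2]` with %c2 = xor(%c, <shifted bits of %r>),
// and every read is rewritten with the same permutation, so each value is
// still loaded from the slot it was stored to.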

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
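
// Typical driver usage (a sketch, not part of this file; assumes the usual
// registration of the pass under the "nvgpu-optimize-shared-memory" flag):
//
//   PassManager pm(module.getContext());
//   pm.addPass(nvgpu::createOptimizeSharedMemoryPass());
//   if (failed(pm.run(module)))
//     /* handle failure */;
//
// or, from the command line: mlir-opt --nvgpu-optimize-shared-memory in.mlir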