//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this could be made configurable in
/// the future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
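  //
  // Worked example (illustrative only; the shape is a hypothetical choice,
  // not taken from this file): for a `memref<64x16xf16>` in shared memory,
  //   N = log2(128 / 16) = 3   (8 f16 elements per 128-bit vector)
  //   M = log2(16)       = 4   (16 elements per row)
  // and permuteEveryN = max(1, 128 / (16 * 16 / 8)) = 4, i.e. the swizzle
  // pattern advances once every 4 rows, because 4 rows of 32 bytes each
  // share one 128-byte shared memory line.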
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = arith::ConstantIndexOp::create(b, loc, mask);
  srcBits = arith::AndIOp::create(b, loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal =
          arith::ConstantIndexOp::create(b, loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      arith::XOrIOp::create(b, loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}
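// Illustrative effect on the IR (a sketch continuing the hypothetical
// memref<64x16xf16> example above, not verbatim compiler output): a shared
// memory store
//   memref.store %v, %shm[%i, %j] : memref<64x16xf16, 3>
// has its column index XOR-swizzled with bits of the row index, conceptually:
//   %mask     = arith.constant 4 : index   // (1 << (M-N)) - 1, shifted by
//   %srcBits  = arith.andi %i, %mask       //   log2(permuteEveryN) = 2
//   %c1       = arith.constant 1 : index   // shlBits = N - 2 = 1
//   %shifted  = arith.shli %srcBits, %c1
//   %jSwizzle = arith.xori %j, %shifted
//   memref.store %v, %shm[%i, %jSwizzle] : memref<64x16xf16, 3>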

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}
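// For example (hypothetical IR, for illustration): a 2-D access such as
//   memref.load %shm[%i, %j] : memref<64x16xf16, 3>
// passes the filter above, while a 1-D `vector.load %shm[%i]` (fewer than
// two indices) or an unlisted op kind makes the helper return failure(),
// aborting the whole transformation.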

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0-D memrefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check whether the transformation is necessary given the assumption of
  // 128-bit accesses: bail out if dim[rank-1] is small enough that 8 or more
  // rows fit in a single 128B line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();
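  // With the defaults, threadGroupSize = 1 << (7 - log2(128 / 8)) = 8, so the
  // swizzle is only applied when fewer than 8 rows share a 128-byte line.
  // E.g. for the hypothetical memref<64x16xf16> used in the examples above,
  // rowsPerLine = (8 * 128 / 16) / 16 = 4 < 8, and the transformation
  // proceeds.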

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
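
// Typical usage sketch (assumptions: the pass is registered under its
// tablegen'd command-line name, and `pm` is a caller-owned PassManager;
// neither detail is quoted from this file):
//   mlir-opt --nvgpu-optimize-shared-memory kernel.mlir
// or, programmatically:
//   pm.addPass(mlir::nvgpu::createOptimizeSharedMemoryPass());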