CreateAsyncGroups.cpp
//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}
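// For illustration only (assumed IR, not part of the original source): a
// transfer such as
//   %v = vector.transfer_read %mem[%i, %c0], %pad
//          {in_bounds = [true]} : memref<8x4xf32>, vector<4xf32>
// has a minor-identity permutation map, an in-bounds transfer dimension, and
// a unit-stride innermost memref dimension, so isContiguousXferOp returns
// "true" for it.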

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}
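// The two supported shapes, sketched as assumed IR (illustration, not from
// the original source):
//   %m  = vector.create_mask %sz : vector<4xi1>                      // Case 1
//   %mm = vector.create_mask %s0, %s1 : vector<2x4xi1>
//   %m2 = vector.extract %mm[0] : vector<4xi1> from vector<2x4xi1>   // Case 2
// Case 1 yields TransferMask{%m's defining op, {}}; Case 2 yields
// TransferMask{%mm's defining op, {0}}.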

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take the number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last
  // `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}
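// Sketch of the IR built above for extractPosition = {p} and create_mask
// operands (%s0, %s1); names and values are assumed, for illustration only:
//   %p   = arith.constant <p> : index
//   %cmp = arith.cmpi slt, %p, %s0 : index
//   %c0  = arith.constant 0 : index
//   %n   = arith.select %cmp, %s1, %c0 : index
// i.e., the extracted row of the 2-D mask reads %s1 elements if it lies
// within the first %s0 rows, and 0 elements otherwise.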

/// Return "true" if a copy of the given vector type to/from the given memref
/// type can be lowered to a supported async copy.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async, the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined."
  return true;
}
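// For example (illustrative): vector<4xf32> (16 bytes), vector<2xf32>
// (8 bytes), and vector<1xf32> (4 bytes) are supported copy sizes, while
// vector<3xf32> (12 bytes) is not.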

void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
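  // Illustrative sketch of the rewrite, with assumed and abbreviated IR (not
  // from the original source): a global->shared copy pair such as
  //   %v = vector.transfer_read %global[%i], %cst
  //          : memref<1024xf32>, vector<4xf32>
  //   vector.transfer_write %v, %shared[%j]
  //          : vector<4xf32>, memref<128xf32, 3>
  // becomes
  //   %t = nvgpu.device_async_copy %global[%i], %shared[%j], 4
  //          : memref<1024xf32> to memref<128xf32, 3>
  //   %g = nvgpu.device_async_create_group %t
  //   nvgpu.device_async_wait %g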
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding: a masked read is only supported
    // with zero padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different address space (not shared memory).
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }
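    // At this point, the copies in `group` are adjacent up to side-effect-free
    // ops and reads from non-shared memory, so the single async group and wait
    // emitted below cover the entire group.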

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}