MLIR  20.0.0git
CreateAsyncGroups.cpp
Go to the documentation of this file.
1 //===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
18 
19 using namespace mlir;
20 
21 /// Return "true" if the given vector transfer op is contiguous and suitable
22 /// for replacement with an async copy.
23 template <typename OpTy>
24 static bool isContiguousXferOp(OpTy op) {
25  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
26  op.hasPureBufferSemantics() &&
28  cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
29 }
30 
31 /// Return "true" if the given op is a contiguous and suitable
32 /// vector.transfer_write or vector.store op.
33 static bool isContiguousStore(Operation *write) {
34  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
35  return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
36  // vector.store are always contiguous.
37  return isa<vector::StoreOp>(write);
38 }
39 
40 /// Return "true" if the given op is a contiguous and suitable
41 /// vector.transfer_read or vector.load op.
42 static bool isContiguousRead(Operation *read) {
43  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
44  return isContiguousXferOp(transferRead);
45  // vector.load are always contiguous.
46  return isa<vector::LoadOp>(read);
47 }
48 
49 namespace {
50 /// A vector.create_mask op and extract position.
51 struct TransferMask {
52  vector::CreateMaskOp createMaskOp;
54 };
55 } // namespace
56 
57 /// If the given vector load op has a mask that is defined by
58 /// vector.create_mask, return that op.
59 static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
60  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
61  if (!transferRead || !transferRead.getMask())
62  return TransferMask{{}, {}};
63  assert(transferRead.getMask().getType().getRank() == 1 &&
64  "expected 1-D mask");
65 
66  // Case 1: Mask is the result of a vector.create_mask.
67  if (auto maskOp =
68  transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
69  return TransferMask{maskOp, {}};
70 
71  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
72  if (auto extractOp =
73  transferRead.getMask().getDefiningOp<vector::ExtractOp>())
74  if (auto maskOp =
75  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
76  return TransferMask{maskOp,
77  SmallVector<int64_t>(extractOp.getStaticPosition())};
78 
79  // All other cases: not supported.
80  return failure();
81 }
82 
83 /// Build an SSA value that represents the number of read elements.
85  Operation *readOp) {
86  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
87  assert(succeeded(transferMask) && "invalid transfer mask");
88 
89  // No mask => no num_read_elements.
90  if (!transferMask->createMaskOp)
91  return Value();
92 
93  // No extract: return size of "ones" segment in the mask.
94  if (transferMask->extractPosition.empty()) {
95  assert(transferMask->createMaskOp.getNumOperands() == 1 &&
96  "expected single operand");
97  return transferMask->createMaskOp.getOperand(0);
98  }
99 
100  // vector.extract(vector.create_mask).
101  // If extract_pos < num_ones, take number of elements from the least
102  // significant dimension. (Do this for all dimensions and bit-AND the
103  // conditions.)
104  assert(transferMask->createMaskOp.getVectorType().getRank() -
105  transferMask->extractPosition.size() ==
106  1 &&
107  "expected N-D -> (N-1)-D extract");
108  Value cond;
109  // Note: There is one more `sz` than `pos`. The loop end with the last `pos`.
110  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
111  transferMask->createMaskOp->getOperands())) {
112  Value cmp =
113  b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
114  b.create<arith::ConstantIndexOp>(loc, pos), sz);
115  if (!cond) {
116  cond = cmp;
117  continue;
118  }
119  cond = b.create<arith::AndIOp>(loc, cmp, cond);
120  }
121  return b.create<arith::SelectOp>(
122  loc, cond, transferMask->createMaskOp->getOperands().back(),
123  b.create<arith::ConstantIndexOp>(loc, 0));
124 }
125 
126 /// Return "true" if the conversion to async copy is supported by "async copy".
127 static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
128  VectorType vecType) {
129  assert(vecType.getRank() == 1 && "expected 1-D vector");
130  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};
131 
132  // Condition 1: the copy size must be supported.
133  bool supportedCopySize = false;
134  int64_t numElements = vecType.getNumElements();
135  Type elementType = vecType.getElementType();
136  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
137  if (alignmentInBytes * 8 ==
138  numElements * elementType.getIntOrFloatBitWidth()) {
139  supportedCopySize = true;
140  break;
141  }
142  }
143  if (!supportedCopySize)
144  return false;
145 
146  // TODO: Condition 2: the alignments must be supported. For cp.async the
147  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
148  // a multiple of the access size. If an address is not properly aligned, the
149  // resulting behavior is undefined.".
150  return true;
151 }
152 
154  bool bypassL1) {
155  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;
156 
157  // Look for all the copy that can be converted to async copy ops.
158  op->walk([&](Operation *writeOp) {
159  // Look for contiguous 1D vector store into shared memory.
160  if (!isContiguousStore(writeOp))
161  return;
162  Value vectorVal = nvgpu::getValueStored(writeOp);
163  if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
164  return;
165  Value storeBase = nvgpu::getMemrefOperand(writeOp);
166  if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
167  cast<MemRefType>(storeBase.getType())))
168  return;
169 
170  // The stored vector must originate from a contiguous 1D vector load.
171  Operation *readOp = vectorVal.getDefiningOp();
172  if (readOp == nullptr || !isContiguousRead(readOp))
173  return;
174  Value loadBase = nvgpu::getMemrefOperand(readOp);
175  // Should be reading from global memory (not shared memory).
176  if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
177  cast<MemRefType>(loadBase.getType())))
178  return;
179 
180  // Look for compatible mask and padding.
181  if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
182  if (Value mask = transferRead.getMask()) {
183  if (getConstantIntValue(transferRead.getPadding()) ==
184  static_cast<int64_t>(0))
185  return;
186  if (failed(getMaskOp(readOp)))
187  return;
188  }
189  }
190 
191  // Check whether both accesses are supported before we emit: this is
192  // necessary to ensure the correctness of DeviceAsyncCopyOp.
193  VectorType vecType = cast<VectorType>(vectorVal.getType());
194 
195  if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
196  vecType) ||
197  !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
198  vecType))
199  return;
200 
201  copyToSharedMem.insert(writeOp);
202  return;
203  });
204 
205  while (!copyToSharedMem.empty()) {
206  // Start a group with the first write.
208  Operation *writeOp = *copyToSharedMem.begin();
209  copyToSharedMem.remove(writeOp);
210  group.push_back(writeOp);
211  Operation *nextNode = writeOp;
212 
213  // Look in the next nodes for more copies to add to the same group.
214  while ((nextNode = nextNode->getNextNode())) {
215  // Ignore ops without side effects.
216  auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
217  if (memInterface && memInterface.hasNoEffect() &&
219  continue;
220  // Ignore read from a different address space.
221  if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
222  Operation *readOp = nextNode;
223  Value memrefOperand = nvgpu::getMemrefOperand(readOp);
224  if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
225  cast<MemRefType>(memrefOperand.getType()))) {
226  continue;
227  }
228  }
229  if (copyToSharedMem.count(nextNode)) {
230  // Found another copy, add it to the group.
231  copyToSharedMem.remove(nextNode);
232  group.push_back(nextNode);
233  continue;
234  }
235  // If the op is something else stop the accumulating op in the group.
236  break;
237  }
238 
239  // Emit the group.
240  SmallVector<Value> tokens;
241  for (Operation *writeOp : group) {
242  rewriter.setInsertionPoint(writeOp);
243  Value vectorVal = nvgpu::getValueStored(writeOp);
244  auto vectorType = cast<VectorType>(vectorVal.getType());
245  int64_t numElements = vectorType.getNumElements();
246  Operation *readOp = vectorVal.getDefiningOp();
247  Value storeBase = nvgpu::getMemrefOperand(writeOp);
248  Value loadBase = nvgpu::getMemrefOperand(readOp);
249  Value numReadElements =
250  buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
251  auto dstMemref = cast<MemRefType>(storeBase.getType());
252  int64_t sizeInBytes =
253  (dstMemref.getElementTypeBitWidth() * numElements) / 8;
254  // bypass_l1 only possible with 16 byte transfer.
255  Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
257  /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
258  /*src=*/loadBase,
259  /*srcIndices=*/nvgpu::getIndices(readOp),
260  /*dstElements=*/rewriter.getIndexAttr(numElements),
261  /*srcElements=*/numReadElements,
262  /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
263  : UnitAttr());
264  tokens.push_back(token);
265  }
266 
267  // Create the group and wait for it right after.
268  Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
270  tokens);
271  rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
272  nullptr);
273  // Clean up old stores.
274  for (Operation *writeOp : group)
275  rewriter.eraseOp(writeOp);
276  }
277 }
static bool isContiguousStore(Operation *write)
Return "true" if the given op is a contiguous and suitable vector.transfer_write or vector.store op.
static bool isContiguousXferOp(OpTy op)
Return "true" if the given vector transfer op is contiguous and suitable for replacement with an async copy.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType, VectorType vecType)
Return "true" if the conversion to async copy is supported by "async copy".
static bool isContiguousRead(Operation *read)
Return "true" if the given op is a contiguous and suitable vector.transfer_read or vector.load op.
static FailureOr< TransferMask > getMaskOp(Operation *loadOp)
If the given vector load op has a mask that is defined by vector.create_mask, return that op.
static Value buildNumReadElements(OpBuilder &b, Location loc, Operation *readOp)
Build an SSA value that represents the number of read elements.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void remove()
Remove the operation from its parent block, but don't delete it.
Definition: Operation.cpp:547
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value getMemrefOperand(Operation *op)
Get the memref that is loaded from/stored into by the given load/store operation.
Definition: Utils.cpp:68
Value getValueStored(Operation *op)
Get the value that is stored by the given store operation.
Definition: Utils.cpp:58
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...