CreateAsyncGroups.cpp
//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}
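
// Illustrative example (added for exposition, not part of the upstream file):
// a 1-D read such as the following is considered contiguous, assuming %src is
// a row-major memref whose innermost dimension has unit stride:
//
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
//       : memref<128x64xf32>, vector<4xf32>
//
// A transposing permutation_map or a non-unit innermost stride makes
// isContiguousXferOp return "false".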

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}
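
// For orientation (example added here, not in the upstream file): the pattern
// this transform ultimately rewrites is a contiguous 1-D load from global
// memory whose result is immediately stored into workgroup (shared) memory:
//
//   %v = vector.load %global[%i] : memref<1024xf32>, vector<4xf32>
//   vector.store %v, %shared[%j]
//       : memref<256xf32, #gpu.address_space<workgroup>>, vector<4xf32>
//
// The two predicates above gate which load/store ops may participate.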

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}
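
// Example of the two supported mask shapes (illustrative, not from the
// upstream file), assuming %sz0 and %sz1 are index values:
//
//   Case 1: %mask = vector.create_mask %sz0 : vector<4xi1>
//   Case 2: %m2d  = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
//           %mask = vector.extract %m2d[0] : vector<4xi1> from vector<2x4xi1>
//
// Any other mask producer makes getMaskOp return failure().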

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
                              arith::ConstantIndexOp::create(b, loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = arith::AndIOp::create(b, loc, cmp, cond);
  }
  return arith::SelectOp::create(
      b, loc, cond, transferMask->createMaskOp->getOperands().back(),
      arith::ConstantIndexOp::create(b, loc, 0));
}
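
// Sketch of the IR produced for Case 2 above (illustrative; the value names
// are made up): for extractPosition = [0] and create_mask operands
// (%sz0, %sz1), the generated num_read_elements is roughly:
//
//   %c0   = arith.constant 0 : index
//   %cmp  = arith.cmpi slt, %c0, %sz0 : index
//   %nums = arith.select %cmp, %sz1, %c0 : index
//
// i.e. the innermost mask size if the extracted row lies inside the "ones"
// region, and 0 otherwise.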

/// Return "true" if a copy of the given memref/vector types can be lowered to
/// a supported async copy (cp.async).
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}
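
// Worked example (added for clarity): cp.async only supports 4-, 8-, and
// 16-byte transfers, so vector<4xf32> (4 x 32 bits = 16 bytes) and
// vector<8xf16> (8 x 16 bits = 16 bytes) pass the size check, while
// vector<3xf32> (12 bytes) and vector<2xi8> (2 bytes) do not.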

void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different address space.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with 16-byte transfers.
      Value token = nvgpu::DeviceAsyncCopyOp::create(
          rewriter, writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create(
        rewriter, op->getLoc(),
        nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens);
    nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken,
                                     nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
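
// End-to-end sketch of the rewrite (illustrative only; exact operand order and
// attribute spelling follow the nvgpu dialect): a qualifying pair such as
//
//   %v = vector.transfer_read %global[%i], %f0 {in_bounds = [true]}
//       : memref<1024xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
//       : vector<4xf32>, memref<256xf32, #gpu.address_space<workgroup>>
//
// becomes roughly
//
//   %t = nvgpu.device_async_copy %global[%i], %shared[%j], 4
//       : memref<1024xf32> to memref<256xf32, #gpu.address_space<workgroup>>
//   %g = nvgpu.device_async_create_group %t
//   nvgpu.device_async_wait %g
//
// with consecutive copies sharing one create_group/wait pair.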