#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         isLastMemrefDimUnitStride(
             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}
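// Illustrative example (a sketch with assumed SSA names, not taken from the
// MLIR tests): a transfer pair that the predicates above accept. The read has
// a minor-identity permutation map, is in-bounds, and both memrefs have unit
// stride in their last dimension:
//
//   %v = vector.transfer_read %src[%i], %cst {in_bounds = [true]}
//       : memref<1024xf32>, vector<4xf32>
//   vector.transfer_write %v, %dst[%j] {in_bounds = [true]}
//       : vector<4xf32>, memref<128xf32, #gpu.address_space<workgroup>>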
namespace {
/// A vector.create_mask op and the position at which a 1-D mask is extracted
/// from it (empty if the mask is used directly).
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: The mask is the direct result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: The mask is a slice extracted from an N-D vector.create_mask.
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases are not supported.
  return failure();
}
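// The two mask shapes accepted above, shown as IR (a sketch with assumed SSA
// names, not verifier-checked):
//
//   Case 1: %mask = vector.create_mask %n : vector<4xi1>
//
//   Case 2: %m = vector.create_mask %i, %n : vector<2x4xi1>
//           %mask = vector.extract %m[0] : vector<4xi1> from vector<2x4xi1>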
/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask: the entire vector is read.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return the size of the "ones" segment in the 1-D mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask): only a rank-reducing extract down to
  // a 1-D mask is supported.
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");

  // If the extract position is inside the "ones" region of the mask, the
  // number of read elements is the size of the innermost mask dimension.
  // Otherwise, it is 0.
  Value cond;
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos),
                                sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}
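// For case 2 above with extract position [0], the IR built here amounts to
// the following (a sketch with assumed SSA names):
//
//   %c0 = arith.constant 0 : index
//   %in_ones = arith.cmpi slt, %c0, %i : index     // pos 0 < outer mask size?
//   %num = arith.select %in_ones, %n, %c0 : index  // innermost size, else 0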
/// Return "true" if the conversion to async copy is supported by the
/// hardware's copy-size constraints.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be 4, 8, or 16 bytes.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // Condition 2 (natural alignment of the source and destination addresses to
  // a multiple of the access size) is not checked here.
  (void)memrefType;
  return true;
}
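// Worked examples of condition 1: vector<4xf32> and vector<8xf16> are both
// 16-byte copies and are accepted; vector<3xf32> is 12 bytes and is rejected.
// A hypothetical sanity check illustrating this (not part of the original
// file):
static LLVM_ATTRIBUTE_UNUSED void exampleSupportedCopySizes(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto memref = MemRefType::get({1024}, f32);
  // 4 x 32 bits = 16 bytes: supported.
  assert(resultsInSupportedAsyncCopy(memref, VectorType::get({4}, f32)));
  // 3 x 32 bits = 12 bytes: not supported.
  assert(!resultsInSupportedAsyncCopy(memref, VectorType::get({3}, f32)));
}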
/// Convert global->shared vector transfers to async device copies.
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1-D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1-D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // The load must read from global memory, not from shared memory.
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for a compatible mask and zero padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check that both accesses have supported copy sizes.
    VectorType vecType = cast<VectorType>(vectorVal.getType());
    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first remaining copy.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look at the subsequent ops for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without memory effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from shared memory; they do not interfere with the
      // global->shared copies being grouped.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Value memrefOperand = nvgpu::getMemrefOperand(nextNode);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another candidate copy; add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op with side effects ends the group.
      break;
    }

    // Emit the group: one async copy per collected store.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // Bypassing the L1 cache is only possible with 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase, /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right away.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              /*numGroups=*/nullptr);
    // Clean up the original stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
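// End-to-end sketch of the rewrite (assumed IR, abbreviated): the
// transfer_read/transfer_write pair shown after isContiguousRead above
// becomes, roughly:
//
//   %token = nvgpu.device_async_copy %src[%i], %dst[%j], 4
//       : memref<1024xf32> to memref<128xf32, #gpu.address_space<workgroup>>
//   %g = nvgpu.device_async_create_group %token
//   nvgpu.device_async_wait %g
//
// A minimal, hypothetical driver for the transformation (the pass wrapper
// below is an assumption, not part of this file):
//
//   struct ExampleAsyncGroupsPass
//       : PassWrapper<ExampleAsyncGroupsPass, OperationPass<>> {
//     void runOnOperation() override {
//       IRRewriter rewriter(&getContext());
//       nvgpu::createAsyncGroups(rewriter, getOperation(), /*bypassL1=*/true);
//     }
//   };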