/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}
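// Illustrative sketch (assumed IR, not from this file): a transfer such as
//   %v = vector.transfer_read %src[%i], %cst {in_bounds = [true]}
//       : memref<1024xf32>, vector<4xf32>
// uses a minor-identity permutation map on a memref whose innermost dimension
// has unit stride, so it passes this predicate; a transposing map or a
// strided innermost dimension would fail it.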
/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store op is always contiguous.
  return isa<vector::StoreOp>(write);
}
/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load op is always contiguous.
  return isa<vector::LoadOp>(read);
}
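// Hedged example of the load/store pair these two predicates accept (assumed
// IR, not from this file):
//   %v = vector.load %global[%i] : memref<1024xf32>, vector<4xf32>
//   vector.store %v, %shared[%j]
//       : memref<256xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// Both ops are contiguous by construction, making the pair a candidate for
// rewriting into an nvgpu.device_async_copy below.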
namespace {
/// A vector.create_mask op and its extract position, if any.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace
/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: The mask is directly defined by a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: The mask is a vector.extract of a vector.create_mask.
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases are not supported.
  return failure();
}
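// The two supported mask producers, sketched as assumed IR (syntax is
// illustrative and may differ between MLIR versions):
//   Case 1: %mask = vector.create_mask %sz : vector<4xi1>
//   Case 2: %m2d  = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
//           %mask = vector.extract %m2d[0] : vector<4xi1> from vector<2x4xi1>
// Any other mask producer (e.g. arith ops on i1 vectors) returns failure().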
/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // If there is no mask, all elements are read.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return the size of the "ones" segment of the 1-D mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask): the extracted 1-D slice reads its
  // full innermost mask size if the extract position lies within the "ones"
  // segment of every outer dimension, and 0 elements otherwise.
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more mask operand than there are extract positions;
  // zip pairs each position with the mask size of the same dimension.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos),
                                sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}
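// Worked example (assumed values, for illustration): for
//   %m    = vector.create_mask %c1, %sz : vector<2x4xi1>
//   %mask = vector.extract %m[1] : vector<4xi1> from vector<2x4xi1>
// extractPosition is [1] and the built value is select(1 < %c1, %sz, 0):
// row 1 is entirely masked off unless the outer mask size exceeds 1, in
// which case %sz elements are read.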
/// Return "true" if the conversion to async copy is supported by "async copy".
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported.
  return true;
}
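// Size-check arithmetic (illustrative): vector<4xf32> is 4 * 32 bits =
// 16 bytes (supported), vector<2xf16> is 32 bits = 4 bytes (supported),
// but vector<3xf32> is 12 bytes and is rejected, since cp.async only
// handles 4-, 8-, and 16-byte transfers.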
/// Convert global->shared vector transfers to async device copies.
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1-D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must be produced by a contiguous 1-D vector load
    // that reads from global memory, not shared memory.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // If the read is masked, the padding must be zero and the mask must be
    // in a supported form.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported by async copy.
    VectorType vecType = cast<VectorType>(vectorVal.getType());
    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first remaining copy.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Walk the following ops to add more copies to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Skip ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Skip loads from shared memory.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Value memrefOperand = nvgpu::getMemrefOperand(nextNode);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op ends the group.
      break;
    }

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // Bypassing the L1 cache is only possible for 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase, /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              /*numGroups=*/nullptr);

    // Clean up the original stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
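// End-to-end sketch of the rewrite performed above, with assumed IR (not
// taken from an actual test):
//
//   Before:
//     %v = vector.transfer_read %global[%i], %cst {in_bounds = [true]}
//         : memref<1024xf32>, vector<4xf32>
//     vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
//         : vector<4xf32>, memref<256xf32, #gpu.address_space<workgroup>>
//
//   After (a group with a single copy):
//     %tok = nvgpu.device_async_copy %global[%i], %shared[%j], 4
//         : memref<1024xf32> to memref<256xf32, #gpu.address_space<workgroup>>
//     %grp = nvgpu.device_async_create_group %tok
//     nvgpu.device_async_wait %grp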