'nvgpu' Dialect
The NVGPU dialect provides a bridge between higher-level target-agnostic
dialects (GPU and Vector) and the lower-level target-specific dialect
(the LLVM IR based NVVM dialect) for NVIDIA GPUs. This allows representing
PTX-specific operations while using MLIR high-level dialects such as MemRef
and Vector for memory and target-specific register operands, respectively.
Operations ¶
nvgpu.device_async_copy
(nvgpu::DeviceAsyncCopyOp) ¶
Device-side asynchronous copy
Syntax:
operation ::= `nvgpu.device_async_copy` $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
attr-dict `:` type($src) `to` type($dst)
The nvgpu.device_async_copy op initiates an asynchronous copy operation of
elements from source (global memory) to the destination (shared memory)
without blocking the thread. The async copy is added to a group.
This op is meant to be used with nvgpu.device_async_create_group and
nvgpu.device_async_wait to synchronize copies, as explained in those ops'
descriptions.
The bypassL1 attribute is a hint to the hardware to bypass the L1 cache during
the async copy; the hardware may ignore this hint.
The dstElements attribute is the total number of elements written to the
destination (shared memory).
The srcElements argument is the total number of elements read from the
source (global memory). It is optional; when present, the op reads only
srcElements elements from the source (global memory) and zero-fills
the rest of the elements in the destination (shared memory).
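The interaction between dstElements and srcElements can be sketched with a small host-side reference model. This is only an illustration in Python, not the dialect's implementation: the function name `device_async_copy_model` is hypothetical, the copy is modeled as synchronous and one-dimensional, and the check that srcElements does not exceed dstElements is an assumption.

```python
def device_async_copy_model(src, src_start, dst, dst_start,
                            dst_elements, src_elements=None):
    """Hypothetical reference model: copy with optional zero-fill.

    Writes dst_elements values into dst. Only the first src_elements
    values are read from src; the remainder is filled with zeros.
    """
    if src_elements is None:
        # Without srcElements, every written element comes from the source.
        src_elements = dst_elements
    if not 0 <= src_elements <= dst_elements:
        # Assumption of this sketch, not stated by the op description.
        raise ValueError("srcElements must be in [0, dstElements]")
    for i in range(dst_elements):
        if i < src_elements:
            dst[dst_start + i] = src[src_start + i]
        else:
            dst[dst_start + i] = 0.0


global_mem = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
shared_mem = [-1.0] * 6
device_async_copy_model(global_mem, 0, shared_mem, 0,
                        dst_elements=4, src_elements=2)
print(shared_mem)  # [1.0, 2.0, 0.0, 0.0, -1.0, -1.0]
```

Elements beyond dstElements in the destination are left untouched in this model.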
In order to do a copy and wait for the result we need the following combination:
// copy 1.
%cp1 = nvgpu.device_async_copy %A[%c0], %B[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// copy 2.
%cp2 = nvgpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// group 1 contains copy 1 and copy 2.
%token1 = nvgpu.device_async_create_group %cp1, %cp2
// copy 3.
%cp3 = nvgpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// group 2 contains copy 3.
%token2 = nvgpu.device_async_create_group %cp3
// after the wait copy 1 and copy 2 are complete.
nvgpu.device_async_wait %token1
// after the wait copy 3 is complete.
nvgpu.device_async_wait %token2
Example:
%0 = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
memref<4x5xf32> to memref<2x7x5xf32, 3>
Traits: AttrSizedOperandSegments
Interfaces: InferTypeOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
dstElements | ::mlir::IntegerAttr | index attribute |
bypassL1 | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
dst | memref of any type values |
dstIndices | variadic of index |
src | memref of any type values |
srcIndices | variadic of index |
srcElements | index |
Results: ¶
Result | Description |
---|---|
asyncToken | device async token type |
nvgpu.device_async_create_group
(nvgpu::DeviceAsyncCreateGroupOp) ¶
Device-side asynchronous create group operation
Syntax:
operation ::= `nvgpu.device_async_create_group` $inputTokens attr-dict
The nvgpu.device_async_create_group op creates a group of memory accesses
containing all the pending device_async_copy operations associated with the
argument tokens. Each token can only be part of one group.
It returns a token that can be used to wait until the group fully completes.
This is meant to be used with nvgpu.device_async_wait to synchronize copies,
as explained in that op's description.
Groups are executed in the order they are created.
Example:
%0 = nvgpu.device_async_create_group
Interfaces: InferTypeOpInterface
Operands: ¶
Operand | Description |
---|---|
inputTokens | variadic of device async token type |
Results: ¶
Result | Description |
---|---|
asyncToken | device async token type |
nvgpu.device_async_wait
(nvgpu::DeviceAsyncWaitOp) ¶
Wait for async gpu ops to complete.
Syntax:
operation ::= `nvgpu.device_async_wait` $asyncDependencies attr-dict
The nvgpu.device_async_wait op will block the execution thread until the group
associated with the source token is fully completed.
The optional $numGroups attribute gives an upper bound on the number of
groups that may remain uncompleted when the wait unblocks the thread. For
example, if 16 async groups are pushed and $numGroups is set to 12, then the
thread will unblock when 12 groups or fewer are in flight (4 groups have
completed).
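The arithmetic in the example above can be written out as a short Python sketch. The helper names `wait_unblocks` and `groups_to_complete` are hypothetical and only model the counting described in the text; they are not part of the dialect.

```python
def wait_unblocks(groups_in_flight, num_groups):
    """True when the wait may return: at most num_groups groups are pending."""
    return groups_in_flight <= num_groups


def groups_to_complete(groups_pushed, num_groups):
    """Number of the pushed groups that must finish before the wait returns."""
    return max(0, groups_pushed - num_groups)


# The example from the text: 16 groups pushed, numGroups = 12.
print(groups_to_complete(16, 12))  # 4
print(wait_unblocks(13, 12))       # False
print(wait_unblocks(12, 12))       # True
```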
Example:
nvgpu.device_async_wait %0
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
numGroups | ::mlir::IntegerAttr | 32-bit signless integer attribute |
Operands: ¶
Operand | Description |
---|---|
asyncDependencies | device async token type |
nvgpu.ldmatrix
(nvgpu::LdMatrixOp) ¶
Syntax:
operation ::= `nvgpu.ldmatrix` $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
The nvgpu.ldmatrix op represents loading a matrix fragment from
memory to registers. The source and result type must be compatible
with lowering to the nvvm.ldmatrix instruction. This op represents
the distributed version of a vector.transfer_read as an intermediate
step between lowering from vector.transfer_read to nvvm.ldmatrix.
This operation is meant to follow the semantics described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
Example:
%0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} :
memref<?x?xf16, 3> -> vector<4x2xf16>
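The result type in the example above, vector<4x2xf16> for numTiles = 4, can be related to the warp-level load with a little arithmetic. The sketch below assumes, following the PTX ISA description of ldmatrix rather than anything stated on this page, that each tile is an 8x8 matrix of 16-bit elements shared by the 32 threads of a warp. The function name `ldmatrix_result_shape` is hypothetical.

```python
WARP_SIZE = 32
TILE_ROWS, TILE_COLS = 8, 8   # assumed tile shape (PTX ISA ldmatrix)
TILE_ELEMENT_BITS = 16        # assumed tile element width (PTX ISA ldmatrix)


def ldmatrix_result_shape(num_tiles, element_bits):
    """Per-thread result shape: one register's worth of data per tile."""
    bits_per_thread_per_tile = (
        TILE_ROWS * TILE_COLS * TILE_ELEMENT_BITS // WARP_SIZE
    )
    elements_per_tile = bits_per_thread_per_tile // element_bits
    return (num_tiles, elements_per_tile)


print(ldmatrix_result_shape(4, 16))  # (4, 2)
```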
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
transpose | ::mlir::BoolAttr | bool attribute |
numTiles | ::mlir::IntegerAttr | 32-bit signless integer attribute |
Operands: ¶
Operand | Description |
---|---|
srcMemref | memref of any type values |
indices | variadic of index |
Results: ¶
Result | Description |
---|---|
res | vector of any type values |
nvgpu.mbarrier.arrive
(nvgpu::MBarrierArriveOp) ¶
Performs arrive operation on the nvgpu.mbarrier.
Syntax:
operation ::= `nvgpu.mbarrier.arrive` $barriers `[` $mbarId `]` attr-dict `:` type($barriers) `->` type($token)
The Op performs an arrive-on operation on the mbarrier object and returns an
nvgpu.mbarrier.token.
For more information, see https://docs.nvidia.com/cuda/parallel-thread-execution/#arrive-on-operation-on-mbarrier-object
Example:
%token = nvgpu.mbarrier.arrive %barrier[%c0] : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token
Interfaces: InferTypeOpInterface
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
mbarId | index |
Results: ¶
Result | Description |
---|---|
token |
nvgpu.mbarrier.arrive.expect_tx
(nvgpu::MBarrierArriveExpectTxOp) ¶
Performs expect_tx operation on the nvgpu.mbarrier.
Syntax:
operation ::= `nvgpu.mbarrier.arrive.expect_tx` $barriers `[` $mbarId `]` `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)
A thread executing the Op performs an expect-tx operation on the mbarrier object selected by the $barriers and $mbarId operands. The expect-tx operation, with a $txcount argument, increases the tx-count of the mbarrier object by the value specified by $txcount. This makes the current phase of the mbarrier object expect and track the completion of additional asynchronous transactions.
The $txcount operand specifies the number of elements for the expect-tx operation.
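How the tx-count interacts with phase completion can be illustrated with a toy model. This Python sketch assumes, following the PTX ISA rather than this page, that a phase completes once both the pending arrival count and the tx-count reach zero. The class `MBarrierModel` and its methods are hypothetical names, and the arrival is modeled as a separate step for clarity.

```python
class MBarrierModel:
    """Toy single-barrier model tracking arrivals and tx-count."""

    def __init__(self, expected_arrivals):
        self.expected_arrivals = expected_arrivals
        self.pending_arrivals = expected_arrivals
        self.tx_count = 0
        self.phase = 0

    def expect_tx(self, txcount):
        # expect-tx: the current phase now tracks txcount more transactions.
        self.tx_count += txcount

    def complete_tx(self, txcount):
        # Issued by the asynchronous operation when its data has landed.
        self.tx_count -= txcount
        self._maybe_complete_phase()

    def arrive(self):
        self.pending_arrivals -= 1
        self._maybe_complete_phase()

    def _maybe_complete_phase(self):
        if self.pending_arrivals == 0 and self.tx_count == 0:
            self.phase += 1
            self.pending_arrivals = self.expected_arrivals


barrier = MBarrierModel(expected_arrivals=1)
barrier.expect_tx(4096)
barrier.arrive()
phase_after_arrive = barrier.phase   # still 0: transactions outstanding
barrier.complete_tx(4096)
phase_after_tx = barrier.phase       # 1: arrivals and tx-count both zero
print(phase_after_arrive, phase_after_tx)  # 0 1
```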
Example:
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %ic0 : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
txcount | index |
mbarId | index |
predicate | 1-bit signless integer |
nvgpu.mbarrier.arrive.nocomplete
(nvgpu::MBarrierArriveNoCompleteOp) ¶
Performs a non-blocking arrive operation on the nvgpu.mbarrier.
Syntax:
operation ::= `nvgpu.mbarrier.arrive.nocomplete` $barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers) `->` type($token)
The Op performs an arrive-on operation on the mbarrier object and returns an
nvgpu.mbarrier.token.
The Op does not cause the nvgpu.mbarrier to complete its current phase.
Example:
%token = nvgpu.mbarrier.arrive.nocomplete %barrier[%c0], %count : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token
Interfaces: InferTypeOpInterface
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
mbarId | index |
count | index |
Results: ¶
Result | Description |
---|---|
token |
nvgpu.mbarrier.create
(nvgpu::MBarrierCreateOp) ¶
Creates an nvgpu.mbarrier object.
Syntax:
operation ::= `nvgpu.mbarrier.create` attr-dict `->` type($barriers)
The Op generates one or more mbarrier objects, each of which is a barrier
created in shared memory that supports various synchronization behaviors for
threads.
The mbarrier object has the following type and alignment requirements:
Type: .b64, Alignment: 8, Memory space: .shared
Example:
%barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
Results: ¶
Result | Description |
---|---|
barriers | mbarrier barrier type |
nvgpu.mbarrier.init
(nvgpu::MBarrierInitOp) ¶
Initializes the nvgpu.mbarrier.
Syntax:
operation ::= `nvgpu.mbarrier.init` $barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)
The Op initializes the mbarrier
object with the given number of threads.
Example:
%num_threads = gpu.block_dim x
%barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
count | index |
mbarId | index |
predicate | 1-bit signless integer |
nvgpu.mbarrier.test.wait
(nvgpu::MBarrierTestWaitOp) ¶
Checks if the nvgpu.mbarrier
has completed its current phase.
Syntax:
operation ::= `nvgpu.mbarrier.test.wait` $barriers `[` $mbarId `]` `,` $token attr-dict `:` type($barriers) `,` type($token)
Checks whether the mbarrier object has completed the phase. It is a non-blocking instruction which tests for the completion of the phase.
Example:
%isComplete = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>, !nvgpu.mbarrier.token
Interfaces: InferTypeOpInterface
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
token | |
mbarId | index |
Results: ¶
Result | Description |
---|---|
waitComplete | 1-bit signless integer |
nvgpu.mbarrier.try_wait.parity
(nvgpu::MBarrierTryWaitParityOp) ¶
Waits for the nvgpu.mbarrier
to complete its current phase.
Syntax:
operation ::= `nvgpu.mbarrier.try_wait.parity` $barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)
Checks whether the mbarrier object has completed the phase. It is a potentially blocking instruction which tests for the completion of the phase. A suspended thread resumes execution when the specified phase completes, or before the phase completes once a system-dependent time limit has elapsed.
The $phaseParity operand specifies whether to wait on an even phase (0) or an
odd phase (1).
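The role of the parity bit can be sketched with a tiny predicate. This is one reading of the PTX ISA and is not stated on this page: phases are assumed to be numbered from 0, and a wait on parity p is assumed to be satisfied once the barrier has moved past a phase of that parity, so that its current phase has the opposite parity. The function name `parity_wait_is_satisfied` is hypothetical.

```python
def parity_wait_is_satisfied(current_phase, phase_parity):
    """Model of a parity wait under the assumptions stated above."""
    if phase_parity not in (0, 1):
        raise ValueError("phaseParity must be 0 (even) or 1 (odd)")
    # The phase with the requested parity is complete exactly when the
    # barrier is currently in a phase of the other parity.
    return current_phase % 2 != phase_parity


# Show which waits would pass as the barrier advances through phases.
for current_phase in range(3):
    for parity in (0, 1):
        satisfied = parity_wait_is_satisfied(current_phase, parity)
        print(current_phase, parity, satisfied)
```

Under these assumptions, a loop that reuses one barrier would alternate the parity it waits on from one iteration to the next.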
Example:
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phaseParity, %ticks : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
Operands: ¶
Operand | Description |
---|---|
barriers | mbarrier barrier type |
phaseParity | 1-bit signless integer |
ticks | index |
mbarId | index |
nvgpu.mma.sp.sync
(nvgpu::MmaSparseSyncOp) ¶
Syntax:
operation ::= `nvgpu.mma.sp.sync` `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
The nvgpu.mma.sp.sync operation performs a warp-distributed MMA operation
where operand A is “structured sparse”. In this case, the matrixA operand
represents the (warp-distributed) non-zero values of operand A, and the
sparseMetadata operand provides the indices.
The sparsity storage format and distribution scheme are described in full in the PTX documentation, whose semantics this operation is meant to follow: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
The way the indices are distributed among the threads in a warp is controlled
by the optional sparsitySelector attribute, which is 0 by default. For
more information, please consult the PTX documentation linked above.
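The relationship between the non-zero values and their indices can be illustrated with a small decompression sketch. It assumes a 2:4 pattern (two non-zero values kept out of every four consecutive elements), which is one of the formats described in the PTX documentation. It ignores how the metadata bits are packed and how data is distributed across the warp. The function name `expand_2_to_4` is hypothetical.

```python
def expand_2_to_4(values, indices):
    """Rebuild a dense row from 2:4 compressed values and their positions.

    values:  non-zero values, two per group of four dense elements
    indices: position (0..3) of each value inside its group
    """
    if len(values) != len(indices) or len(values) % 2 != 0:
        raise ValueError("expected two (value, index) pairs per group")
    dense = []
    for g in range(0, len(values), 2):
        group = [0.0] * 4
        for value, index in zip(values[g:g + 2], indices[g:g + 2]):
            if not 0 <= index < 4:
                raise ValueError("index must be in 0..3")
            group[index] = value
        dense.extend(group)
    return dense


row = expand_2_to_4([5.0, 7.0, 1.0, 2.0], [1, 3, 0, 2])
print(row)  # [0.0, 5.0, 0.0, 7.0, 1.0, 0.0, 2.0, 0.0]
```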
Example (targeting the f16 16x8x32 mma.sp PTX instruction):
nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mmaShape | ::mlir::ArrayAttr | 64-bit integer array attribute |
sparsitySelector | ::mlir::IntegerAttr | 32-bit signless integer attribute |
tf32Enabled | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
matrixA | vector of any type values |
matrixB | vector of any type values |
matrixC | vector of any type values |
sparseMetadata | fixed-length vector of 16-bit signless integer values of length 2 |
Results: ¶
Result | Description |
---|---|
res | vector of any type values |
nvgpu.mma.sync
(nvgpu::MmaSyncOp) ¶
Syntax:
operation ::= `nvgpu.mma.sync` `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
The nvgpu.mma.sync op represents the warp-level matrix-multiply-and-accumulate
(mma) operation that is compatible with nvvm.mma.sync.
The operand and result vector sizes reflect thread-level ownership of
the warp-level mma operation shape. The mmaShape attribute holds the
warp-level matrix-multiply shape.
The nvgpu.mma.sync op serves as an intermediate point between lowering from
vector.contract to nvvm.mma.sync.
This operation is meant to follow the semantics described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
Example:
%res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
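The vector sizes in the example above can be checked against the idea of thread-level ownership. The Python sketch below assumes that each matrix of the warp-level shape is divided evenly among the 32 threads of a warp; the helper name `per_thread_elements` is hypothetical.

```python
WARP_SIZE = 32


def per_thread_elements(mma_shape):
    """Elements of A (m x k), B (k x n) and C (m x n) owned by each thread."""
    m, n, k = mma_shape
    return {
        "A": m * k // WARP_SIZE,
        "B": k * n // WARP_SIZE,
        "C": m * n // WARP_SIZE,
    }


counts = per_thread_elements([16, 8, 16])
print(counts)  # {'A': 8, 'B': 4, 'C': 4}
```

These counts match the example types: vector<4x2xf16> holds 8 elements, while vector<2x2xf16> and vector<2x2xf32> hold 4 each.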
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mmaShape | ::mlir::ArrayAttr | 64-bit integer array attribute |
tf32Enabled | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
matrixA | vector of any type values |
matrixB | vector of any type values |
matrixC | vector of any type values |
Results: ¶
Result | Description |
---|---|
res | vector of any type values |
nvgpu.rcp
(nvgpu::RcpOp) ¶
The reciprocal calculation for vector types
Syntax:
operation ::= `nvgpu.rcp` $in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
attr-dict `:` type($out)
Reciprocal calculation for vector types using nvvm.rcp ops.
Currently, only the approx rounding mode and ftz are supported, and only for
the f32 type.
The input and output must be of the same vector type and shape.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
rounding | ::mlir::nvgpu::RcpRoundingModeAttr | Rounding mode of rcp. Enum cases: approx (APPROX), rn (RN), rz (RZ), rm (RM), rp (RP) |
ftz | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
in | vector of 32-bit float values |
Results: ¶
Result | Description |
---|---|
out | vector of 32-bit float values |
nvgpu.tma.async.load
(nvgpu::TmaAsyncLoadOp) ¶
TMA asynchronous load
Syntax:
operation ::= `nvgpu.tma.async.load` $tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]`
`to` $dst
(`multicast_mask` `=` $multicastMask^ )?
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type($tensorMapDescriptor) `,` type($barriers)
`->` type($dst)
The Op loads a tile memory region from global memory to shared memory by Tensor Memory Access (TMA).
$tensorMapDescriptor is a tensor map descriptor which has information about
the tile shape. The descriptor is created by nvgpu.tma.create.descriptor.
The Op uses the mbarrier-based completion mechanism provided by $barriers.
Traits: AttrSizedOperandSegments
Operands: ¶
Operand | Description |
---|---|
dst | memref of any type values |
barriers | mbarrier barrier type |
tensorMapDescriptor | TensorMap descriptor |
coordinates | variadic of index |
mbarId | index |
multicastMask | 16-bit signless integer |
predicate | 1-bit signless integer |
nvgpu.tma.async.store
(nvgpu::TmaAsyncStoreOp) ¶
TMA asynchronous store
Syntax:
operation ::= `nvgpu.tma.async.store` $src `to` $tensorMapDescriptor `[` $coordinates `]`
(`,` `predicate` `=` $predicate^)?
attr-dict `:` type($src)
`->` type($tensorMapDescriptor)
The Op stores a tile memory region from shared memory to global memory by Tensor Memory Access (TMA).
$tensorMapDescriptor is a tensor map descriptor which has information about
the tile shape. The descriptor is created by nvgpu.tma.create.descriptor.
Traits: AttrSizedOperandSegments
Operands: ¶
Operand | Description |
---|---|
src | memref of any type values |
tensorMapDescriptor | TensorMap descriptor |
coordinates | variadic of index |
predicate | 1-bit signless integer |
nvgpu.tma.create.descriptor
(nvgpu::TmaCreateDescriptorOp) ¶
TMA create descriptor
Syntax:
operation ::= `nvgpu.tma.create.descriptor` $tensor `box` `[` $boxDimensions `]` attr-dict `:` type($tensor) `->` type($tensorMap)
The Op creates a tensor map descriptor object representing a tiled memory
region. To do that, it calls the CUDA Driver’s cuTensorMapEncodeTiled. The
descriptor is used by Tensor Memory Access (TMA).
The tensor operand is the source tensor to be tiled.
The boxDimensions operand is the size of the tiled memory region in each dimension.
For more information see below: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
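How a box size relates to the tensor it tiles can be shown with a short calculation. This Python sketch only illustrates the arithmetic and does not call the CUDA driver; the helper names `tiles_per_dimension` and `elements_per_box` are hypothetical.

```python
import math


def tiles_per_dimension(tensor_shape, box_dimensions):
    """Number of boxes needed to cover the tensor along each dimension."""
    if len(tensor_shape) != len(box_dimensions):
        raise ValueError("rank of tensor and box must match")
    # Integer ceiling division: a partial box at the edge still counts.
    return [-(-dim // box) for dim, box in zip(tensor_shape, box_dimensions)]


def elements_per_box(box_dimensions):
    """Number of elements transferred for one full box."""
    return math.prod(box_dimensions)


print(tiles_per_dimension([128, 64], [64, 64]))  # [2, 1]
print(tiles_per_dimension([100, 64], [64, 64]))  # [2, 1]
print(elements_per_box([64, 64]))                # 4096
```

In the second case the last box along the first dimension extends past the tensor; the out-of-bounds fill type of the tensor map (zero or nan) applies to that region.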
Operands: ¶
Operand | Description |
---|---|
tensor | unranked.memref of any type values |
boxDimensions | variadic of index |
Results: ¶
Result | Description |
---|---|
tensorMap | TensorMap descriptor |
nvgpu.tma.prefetch.descriptor
(nvgpu::TmaPrefetchOp) ¶
Prefetch given nvgpu.tensormap.descriptor
Syntax:
operation ::= `nvgpu.tma.prefetch.descriptor` $tensorMapDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
The Op brings the cache line containing the given $tensorMapDescriptor into
the cache for subsequent use by the tma.async.load instruction.
Operands: ¶
Operand | Description |
---|---|
tensorMapDescriptor | TensorMap descriptor |
predicate | 1-bit signless integer |
nvgpu.warpgroup.generate.descriptor
(nvgpu::WarpgroupGenerateDescriptorOp) ¶
Generate a warpgroup matrix descriptor
Syntax:
operation ::= `nvgpu.warpgroup.generate.descriptor` $tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap) `->` type($descriptor)
This Op builds an nvgpu.warpgroup.descriptor that is used by
nvgpu.warpgroup.mma to perform warpgroup-level matrix multiply and
accumulate.
The descriptor specifies the properties of the matrix in shared memory that is a multiplicand in the matrix multiply and accumulate operation.
Operands: ¶
Operand | Description |
---|---|
tensor | memref of any type values |
tensorMap | TensorMap descriptor |
Results: ¶
Result | Description |
---|---|
descriptor | Warpgroup matrix descriptor type |
nvgpu.warpgroup.mma
(nvgpu::WarpgroupMmaOp) ¶
Syntax:
operation ::= `nvgpu.warpgroup.mma` $descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
The nvgpu.warpgroup.mma op performs the warpgroup-level (4 warps)
matrix-multiply-and-accumulate (mma) operation that results in
nvvm.wgmma.mma_async.
The operands descriptorA and descriptorB are wgmma matrix descriptors that
describe the properties of the matrices in shared memory. The results reflect
thread-level ownership of the warpgroup-level mma operation shape. The shape
is deduced from the descriptor types and the output vector.
The Op encapsulates multiple nvvm.wgmma.mma_async operations to complete
the given shape. Because the nvvm.wgmma.mma_async Op, like its corresponding
PTX instruction, is asynchronous, this Op groups the nvvm.wgmma.mma_async
operations and places them between wgmma.fence.aligned and the
wgmma.commit.group.sync.aligned and wgmma.wait.group.sync.aligned Ops.
Example:
%res = nvgpu.warpgroup.mma %descA, %descB, %acc :
!nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
->
!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
waitGroup | ::mlir::IntegerAttr | 32-bit signless integer attribute |
transposeA | ::mlir::UnitAttr | unit attribute |
transposeB | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
descriptorA | Warpgroup matrix descriptor type |
descriptorB | Warpgroup matrix descriptor type |
matrixC |
Results: ¶
Result | Description |
---|---|
matrixD |
nvgpu.warpgroup.mma.init.accumulator
(nvgpu::WarpgroupMmaInitAccumulatorOp) ¶
Initializes the accumulator matrix
Syntax:
operation ::= `nvgpu.warpgroup.mma.init.accumulator` attr-dict `->` type($matrixC)
This Op generates and initializes the accumulator matrix for
nvgpu.warpgroup.mma
op to perform matrix-multiply-and-accumulate.
Results: ¶
Result | Description |
---|---|
matrixC |
nvgpu.warpgroup.mma.store
(nvgpu::WarpgroupMmaStoreOp) ¶
Syntax:
operation ::= `nvgpu.warpgroup.mma.store` $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
The nvgpu.warpgroup.mma.store op stores the fragmented result in $matrixD
to the given memref.
See the details of the register fragment layout for accumulator matrix D: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d
Note that the op must be executed by a whole warpgroup.
Operands: ¶
Operand | Description |
---|---|
matrixD | |
dstMemref | memref of any type values |
Attributes ¶
RcpRoundingModeAttr ¶
Rounding mode of rcp
Syntax:
#nvgpu.rcp_rounding_mode<
::mlir::nvgpu::RcpRoundingMode # value
>
Enum cases:
- approx (APPROX)
- rn (RN)
- rz (RZ)
- rm (RM)
- rp (RP)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::nvgpu::RcpRoundingMode | an enum of type RcpRoundingMode |
TensorMapInterleaveKindAttr ¶
Tensor map interleave layout type
Syntax:
#nvgpu.interleave<
::mlir::nvgpu::TensorMapInterleaveKind # value
>
Enum cases:
- none (INTERLEAVE_NONE)
- interleave_16b (INTERLEAVE_16B)
- interleave_32b (INTERLEAVE_32B)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::nvgpu::TensorMapInterleaveKind | an enum of type TensorMapInterleaveKind |
TensorMapL2PromoKindAttr ¶
Tensor map L2 promotion type
Syntax:
#nvgpu.l2promo<
::mlir::nvgpu::TensorMapL2PromoKind # value
>
Enum cases:
- none (L2PROMO_NONE)
- l2promo_64b (L2PROMO_64B)
- l2promo_128b (L2PROMO_128B)
- l2promo_256b (L2PROMO_256B)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::nvgpu::TensorMapL2PromoKind | an enum of type TensorMapL2PromoKind |
TensorMapOOBKindAttr ¶
Tensor map out-of-bounds fill type
Syntax:
#nvgpu.oob<
::mlir::nvgpu::TensorMapOOBKind # value
>
Enum cases:
- zero (OOB_ZERO)
- nan (OOB_NAN)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::nvgpu::TensorMapOOBKind | an enum of type TensorMapOOBKind |
TensorMapSwizzleKindAttr ¶
Tensor map swizzling mode of shared memory banks
Syntax:
#nvgpu.swizzle<
::mlir::nvgpu::TensorMapSwizzleKind # value
>
Enum cases:
- none (SWIZZLE_NONE)
- swizzle_32b (SWIZZLE_32B)
- swizzle_64b (SWIZZLE_64B)
- swizzle_128b (SWIZZLE_128B)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::nvgpu::TensorMapSwizzleKind | an enum of type TensorMapSwizzleKind |
Types ¶
DeviceAsyncTokenType ¶
device async token type
Syntax: !nvgpu.device.async.token
nvgpu.device.async.token is a type returned by an asynchronous operation
that runs on the GPU (device). It is used to establish an SSA-based link
between the async operation (e.g. DeviceAsyncCopyOp) and operations that
group or synchronize the async operations (e.g. DeviceAsyncCreateGroupOp,
DeviceAsyncWaitOp).
MBarrierGroupType ¶
mbarrier barrier type
Syntax:
!nvgpu.mbarrier.group<
Attribute, # memorySpace
unsigned # num_barriers
>
This is the type for one or more mbarrier objects in shared memory that are used to synchronize a variable number of threads.
If num_barriers is not set, the number of mbarrier objects is 1.
An mbarrier object is 64 bits wide with 8-byte alignment. The mbarrier object can be initialized and invalidated.
See the PTX ISA for more details.
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memorySpace | Attribute | |
num_barriers | unsigned |
MBarrierTokenType ¶
Syntax: !nvgpu.mbarrier.token
TensorMapDescriptorType ¶
TensorMap descriptor
Syntax:
!nvgpu.tensormap.descriptor<
MemRefType, # tensor
::mlir::nvgpu::TensorMapSwizzleKind, # swizzle
::mlir::nvgpu::TensorMapL2PromoKind, # l2promo
::mlir::nvgpu::TensorMapOOBKind, # oob
::mlir::nvgpu::TensorMapInterleaveKind # interleave
>
nvgpu.tensormap.descriptor is a type that represents a TMA descriptor. It is
a 128-byte object located either in constant space or in a kernel parameter.
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
tensor | MemRefType | |
swizzle | ::mlir::nvgpu::TensorMapSwizzleKind | an enum of type TensorMapSwizzleKind |
l2promo | ::mlir::nvgpu::TensorMapL2PromoKind | an enum of type TensorMapL2PromoKind |
oob | ::mlir::nvgpu::TensorMapOOBKind | an enum of type TensorMapOOBKind |
interleave | ::mlir::nvgpu::TensorMapInterleaveKind | an enum of type TensorMapInterleaveKind |
WarpgroupAccumulatorType ¶
Syntax:
!nvgpu.warpgroup.accumulator<
VectorType # fragmented
>
This type represents the result matrix obtained from nvgpu.warpgroup.mma.
The $fragmented type signifies the distributed or fragmented result
vector that is collectively owned by all the threads in the warpgroup
that executed nvgpu.warpgroup.mma.
See the details of the register fragment layout for accumulator matrix D: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
fragmented | VectorType |
WarpgroupMatrixDescriptorType ¶
Warpgroup matrix descriptor type
Syntax:
!nvgpu.warpgroup.descriptor<
MemRefType # tensor
>
The descriptor specifies the properties of the matrix in shared memory that is a multiplicand in the matrix multiply and accumulate operation.
The descriptor is a 64-bit value contained in a register with the following:
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
|   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
|  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
| BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
See the PTX ISA for more details.
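The bit layout in the table can be made concrete with a small pack and unpack sketch. It only places raw field values at the bit positions listed above; how each value is derived from an address or a stride is not described here and is left out. The field names and the helpers `pack_descriptor` and `unpack_field` are hypothetical.

```python
# Field name -> (lowest bit, width), taken from the table above.
FIELDS = {
    "base_addr": (0, 14),
    "leading_dim": (16, 14),
    "stride": (32, 14),
    "offset": (49, 3),
    "swizzle": (62, 2),
}


def pack_descriptor(**values):
    """Pack raw field values into a 64-bit descriptor; gaps stay zero."""
    descriptor = 0
    for name, value in values.items():
        low, width = FIELDS[name]
        if not 0 <= value < (1 << width):
            raise ValueError(f"{name} does not fit in {width} bits")
        descriptor |= value << low
    return descriptor


def unpack_field(descriptor, name):
    """Extract one raw field value from a packed descriptor."""
    low, width = FIELDS[name]
    return (descriptor >> low) & ((1 << width) - 1)


desc = pack_descriptor(base_addr=0x123, leading_dim=1, stride=64, swizzle=3)
print(unpack_field(desc, "base_addr"))  # 291
print(unpack_field(desc, "stride"))     # 64
print(unpack_field(desc, "swizzle"))    # 3
```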
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
tensor | MemRefType |
Enums ¶
RcpRoundingMode ¶
Rounding mode of rcp
Cases: ¶
Symbol | Value | String |
---|---|---|
APPROX | 0 | approx |
RN | 1 | rn |
RZ | 2 | rz |
RM | 3 | rm |
RP | 4 | rp |
TensorMapInterleaveKind ¶
Tensor map interleave layout type
Cases: ¶
Symbol | Value | String |
---|---|---|
INTERLEAVE_NONE | 0 | none |
INTERLEAVE_16B | 1 | interleave_16b |
INTERLEAVE_32B | 2 | interleave_32b |
TensorMapL2PromoKind ¶
Tensor map L2 promotion type
Cases: ¶
Symbol | Value | String |
---|---|---|
L2PROMO_NONE | 0 | none |
L2PROMO_64B | 1 | l2promo_64b |
L2PROMO_128B | 2 | l2promo_128b |
L2PROMO_256B | 3 | l2promo_256b |
TensorMapOOBKind ¶
Tensor map out-of-bounds fill type
Cases: ¶
Symbol | Value | String |
---|---|---|
OOB_ZERO | 0 | zero |
OOB_NAN | 1 | nan |
TensorMapSwizzleKind ¶
Tensor map swizzling mode of shared memory banks
Cases: ¶
Symbol | Value | String |
---|---|---|
SWIZZLE_NONE | 0 | none |
SWIZZLE_32B | 1 | swizzle_32b |
SWIZZLE_64B | 2 | swizzle_64b |
SWIZZLE_128B | 3 | swizzle_128b |