MLIR

Multi-Level IR Compiler Framework

'nvgpu' Dialect

The NVGPU dialect provides a bridge between higher-level target-agnostic dialects (GPU and Vector) and the lower-level target-specific dialect (the LLVM IR-based NVVM dialect) for NVIDIA GPUs. This allows representing PTX-specific operations while using MLIR high-level dialects such as MemRef and Vector for memory and target-specific register operands, respectively.

Operations 


nvgpu.device_async_copy (nvgpu::DeviceAsyncCopyOp) 

Device-side asynchronous copy

Syntax:

operation ::= `nvgpu.device_async_copy` $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
              attr-dict `:` type($src) `to` type($dst)

The nvgpu.device_async_copy op initiates an asynchronous copy of elements from the source (global memory) to the destination (shared memory) without blocking the thread. The async copy is added to a group.

This op is meant to be used with nvgpu.device_async_create_group and nvgpu.device_async_wait to synchronize copies, as explained in those ops' descriptions.

The bypassL1 attribute is a hint to the hardware to bypass the L1 cache during the async copy; the hardware may ignore this hint.

The dstElements attribute is the total number of elements written to the destination (shared memory).

The srcElements argument is the total number of elements read from the source (global memory). It is optional: when present, the op reads only srcElements elements from the source and zero-fills the remaining elements in the destination (shared memory).
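As an illustration of the dstElements/srcElements semantics above, the following Python sketch models the copy on plain lists. The helper names (`check_copy_size`, `async_copy`) are hypothetical and not part of MLIR. The size check assumes the PTX `cp.async` constraints that a copy moves 4, 8, or 16 bytes and that the L1-bypassing `.cg` variant supports only 16-byte copies; consult the PTX ISA for your target before relying on it.

```python
# Toy model of nvgpu.device_async_copy semantics; helper names are hypothetical.
CP_ASYNC_SIZES = (4, 8, 16)  # assumed PTX cp.async copy sizes, in bytes


def check_copy_size(dst_elements: int, elem_bytes: int, bypass_l1: bool) -> int:
    """Return the copy size in bytes, raising if PTX cp.async would reject it."""
    size = dst_elements * elem_bytes
    if size not in CP_ASYNC_SIZES:
        raise ValueError(f"copy size {size} B is not one of {CP_ASYNC_SIZES}")
    # The L1-bypassing variant (cp.async.cg) only supports 16-byte copies.
    if bypass_l1 and size != 16:
        raise ValueError("bypassL1 requires a 16-byte copy")
    return size


def async_copy(src, src_offset, dst, dst_offset, dst_elements, src_elements=None):
    """Write dst_elements values into dst starting at dst_offset.

    Without src_elements every value comes from src; with it, only the first
    src_elements values are read and the remaining destination slots get 0.
    """
    # Assumption: src_elements larger than dst_elements is clamped (not modeled).
    n_read = dst_elements if src_elements is None else min(src_elements, dst_elements)
    for i in range(dst_elements):
        dst[dst_offset + i] = src[src_offset + i] if i < n_read else 0
```

With the f32 examples in this section (dstElements = 4, 4-byte elements), `check_copy_size(4, 4, False)` gives a 16-byte copy, which is also the only size the sketch accepts together with bypassL1.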

To perform copies and wait for their results, combine the ops as follows:

// copy 1.
%cp1 = nvgpu.device_async_copy %A[%c0], %B[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// copy 2.
%cp2 = nvgpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// group 1 contains copy 1 and copy 2.
%token1 = nvgpu.device_async_create_group %cp1, %cp2
// copy 3.
%cp3 = nvgpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
// group 2 contains copy 3.
%token2 = nvgpu.device_async_create_group %cp3
// after the wait copy 1 and copy 2 are complete.
nvgpu.device_async_wait %token1
// after the wait copy 3 is complete.
nvgpu.device_async_wait %token2

Example:

%0 = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
  memref<4x5xf32> to memref<2x7x5xf32, 3>

Traits: AttrSizedOperandSegments

Attributes: 

Attribute | MLIR Type | Description
dstElements | ::mlir::IntegerAttr | index attribute
bypassL1 | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
dst | memref of any type values
dstIndices | variadic of index
src | memref of any type values
srcIndices | variadic of index
srcElements | index

Results: 

Result | Description
asyncToken | device async token type

nvgpu.device_async_create_group (nvgpu::DeviceAsyncCreateGroupOp) 

Device side asynchronous create group operation

Syntax:

operation ::= `nvgpu.device_async_create_group` $inputTokens attr-dict

The nvgpu.device_async_create_group op creates a group of memory accesses containing all the pending device_async_copy operations associated with argument tokens. Each token can only be part of one group.

It returns a token that can be used to wait until the group fully completes.

This op is meant to be used with nvgpu.device_async_wait to synchronize copies, as explained in that op's description.

Groups are executed in the order they are created.

Example:

%0 = nvgpu.device_async_create_group

Operands: 

Operand | Description
inputTokens | variadic of device async token type

Results: 

Result | Description
asyncToken | device async token type

nvgpu.device_async_wait (nvgpu::DeviceAsyncWaitOp) 

Wait for async gpu ops to complete.

Syntax:

operation ::= `nvgpu.device_async_wait` $asyncDependencies attr-dict

The nvgpu.device_async_wait op blocks the executing thread until the group associated with the source token has fully completed.

The optional $numGroups attribute gives an upper bound on the number of uncompleted groups at which the wait can unblock the thread. For example, if 16 async groups are pushed and $numGroups is set to 12, the thread unblocks when 12 or fewer groups are in flight (that is, once 4 groups have completed).
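A minimal sketch of the $numGroups rule, assuming groups complete in creation order as stated for nvgpu.device_async_create_group. `AsyncGroupQueue` is a hypothetical toy model, not an MLIR API, and treating an omitted $numGroups as 0 (wait for everything outstanding) is an assumption of the sketch.

```python
from collections import deque


class AsyncGroupQueue:
    """Hypothetical toy model of committed async-copy groups."""

    def __init__(self):
        self.in_flight = deque()  # oldest group on the left

    def create_group(self, name):
        self.in_flight.append(name)

    def wait(self, num_groups=0):
        """Return the groups that must finish before the thread unblocks.

        Because groups complete in creation order, the wait is satisfied once
        the oldest groups have drained and at most num_groups remain in flight.
        """
        completed = []
        while len(self.in_flight) > num_groups:
            completed.append(self.in_flight.popleft())
        return completed
```

Pushing 16 groups and calling `wait(num_groups=12)` retires groups 0 to 3 and leaves 12 in flight, matching the example in the text.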

Example:

nvgpu.device_async_wait %0

Attributes: 

Attribute | MLIR Type | Description
numGroups | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands: 

Operand | Description
asyncDependencies | device async token type

nvgpu.ldmatrix (nvgpu::LdMatrixOp) 

Syntax:

operation ::= `nvgpu.ldmatrix` $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)

The nvgpu.ldmatrix op represents loading a matrix fragment from memory to registers. The source and result types must be compatible with lowering to the nvvm.ldmatrix instruction. This op represents the distributed version of a vector.transfer_read, as an intermediate step when lowering from vector.transfer_read to nvvm.ldmatrix.

This operation is meant to follow the semantics described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Example:

%0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} :
  memref<?x?xf16, 3> -> vector<4x2xf16>
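The result type vector<4x2xf16> in the example above follows from how ldmatrix distributes tiles across a warp. The sketch below is a hypothetical helper, not the op's verifier; it assumes the PTX rules that ldmatrix loads 1, 2, or 4 tiles of 8x8 (.x1/.x2/.x4), that each thread receives one 32-bit register per tile, and that .trans applies to 16-bit elements.

```python
def ldmatrix_result_shape(num_tiles: int, elem_bits: int, transpose: bool = False):
    """Per-thread result vector shape (rows, cols) for an ldmatrix load.

    Each 8x8 tile is spread across the 32 threads of a warp, so every thread
    receives one 32-bit register per tile, i.e. 32 // elem_bits elements.
    """
    if num_tiles not in (1, 2, 4):  # PTX .x1, .x2, .x4
        raise ValueError("numTiles must be 1, 2, or 4")
    if 32 % elem_bits:
        raise ValueError("element width must divide 32 bits")
    if transpose and elem_bits != 16:  # .trans is modeled for 16-bit elements only
        raise ValueError("transpose requires 16-bit elements")
    return (num_tiles, 32 // elem_bits)
```

For `numTiles = 4` and f16 elements this yields (4, 2): an 8x8 f16 tile is 1024 bits, and 1024 / 32 threads is 32 bits, or two f16 values, per thread per tile.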

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Attributes: 

Attribute | MLIR Type | Description
transpose | ::mlir::BoolAttr | bool attribute
numTiles | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands: 

Operand | Description
srcMemref | memref of any type values
indices | variadic of index

Results: 

Result | Description
res | vector of any type values

nvgpu.mbarrier.arrive (nvgpu::MBarrierArriveOp) 

Performs an arrive operation on the nvgpu.mbarrier.

Syntax:

operation ::= `nvgpu.mbarrier.arrive` $barriers `[` $mbarId `]` attr-dict `:` type($barriers) `->` type($token)

The Op performs an arrive-on operation on the mbarrier object and returns an nvgpu.mbarrier.token.

For more information, see https://docs.nvidia.com/cuda/parallel-thread-execution/#arrive-on-operation-on-mbarrier-object

Example:

  %token = nvgpu.mbarrier.arrive %barrier[%c0] : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token

Operands: 

Operand | Description
barriers | mbarrier barrier type
mbarId | index

Results: 

Result | Description
token |

nvgpu.mbarrier.arrive.expect_tx (nvgpu::MBarrierArriveExpectTxOp) 

Performs an expect_tx operation on the nvgpu.mbarrier.

Syntax:

operation ::= `nvgpu.mbarrier.arrive.expect_tx` $barriers `[` $mbarId `]` `,` $txcount  (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)

A thread executing the Op performs an expect-tx operation on the mbarrier object at the location specified by the address operand $barriers. The expect-tx operation, with a $txcount argument, increases the tx-count of an mbarrier object by the value specified by $txcount. This makes the current phase of the mbarrier object expect and track the completion of additional asynchronous transactions.

The $txcount operand specifies the number of elements for the expect-tx operation.
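To make the interaction between arrivals and the tx-count concrete, here is a hypothetical Python toy model of a single mbarrier, based on the PTX ISA description that a phase completes once its pending arrival count reaches zero and its tx-count is zero. It assumes that nvgpu.mbarrier.arrive.expect_tx, like the PTX `mbarrier.arrive.expect_tx` instruction, performs the expect-tx and then an arrive-on. `complete_tx` stands in for the completion signal of an asynchronous copy such as nvgpu.tma.async.load, and transaction units are left abstract.

```python
class MBarrierModel:
    """Hypothetical toy model of one mbarrier object (not an MLIR API)."""

    def __init__(self, count: int):
        self.expected = count  # arrival count set by nvgpu.mbarrier.init
        self.pending = count   # arrivals still missing in the current phase
        self.tx_count = 0      # outstanding transactions in the current phase
        self.phase = 0         # number of completed phases

    def _maybe_complete(self):
        # The phase completes when all arrivals happened and tx-count is zero;
        # the next phase then starts with the full arrival count again.
        if self.pending == 0 and self.tx_count == 0:
            self.phase += 1
            self.pending = self.expected

    def expect_tx(self, tx: int):
        self.tx_count += tx

    def arrive(self, count: int = 1):
        self.pending -= count
        self._maybe_complete()

    def arrive_expect_tx(self, tx: int):
        # Assumed ordering: expect-tx first, then arrive-on.
        self.expect_tx(tx)
        self.arrive()

    def complete_tx(self, tx: int):
        # Signaled by the asynchronous copy when its data has landed.
        self.tx_count -= tx
        self._maybe_complete()
```

With a barrier initialized for one thread, `arrive_expect_tx(1024)` alone does not finish the phase; it completes only after the matching `complete_tx(1024)`.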

Example:

  nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %ic0 : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

Operands: 

Operand | Description
barriers | mbarrier barrier type
txcount | index
mbarId | index
predicate | 1-bit signless integer

nvgpu.mbarrier.arrive.nocomplete (nvgpu::MBarrierArriveNoCompleteOp) 

Performs a non-blocking arrive operation on the nvgpu.mbarrier.

Syntax:

operation ::= `nvgpu.mbarrier.arrive.nocomplete` $barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers) `->` type($token)

The Op performs an arrive-on operation on the mbarrier object and returns an nvgpu.mbarrier.token.

The Op does not cause the nvgpu.mbarrier to complete its current phase.

Example:

  %token = nvgpu.mbarrier.arrive.nocomplete %barrier[%c0], %count : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token

Operands: 

Operand | Description
barriers | mbarrier barrier type
mbarId | index
count | index

Results: 

Result | Description
token |

nvgpu.mbarrier.create (nvgpu::MBarrierCreateOp) 

Creates an nvgpu.mbarrier object.

Syntax:

operation ::= `nvgpu.mbarrier.create` attr-dict `->` type($barriers)

The Op generates one or more mbarrier objects. An mbarrier object is a barrier created in shared memory that supports various synchronization behaviors for threads.

The mbarrier object has the following type and alignment requirements: Type: .b64, Alignment: 8, Memory space: .shared

Example:

  %barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

Results: 

Result | Description
barriers | mbarrier barrier type

nvgpu.mbarrier.init (nvgpu::MBarrierInitOp) 

Initialize the nvgpu.mbarrier.

Syntax:

operation ::= `nvgpu.mbarrier.init` $barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)

The Op initializes the mbarrier object with the given number of threads.

Example:

  %num_threads = gpu.block_dim x
  %barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
  nvgpu.mbarrier.init %barrier[%c0], %num_threads : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

Operands: 

Operand | Description
barriers | mbarrier barrier type
count | index
mbarId | index
predicate | 1-bit signless integer

nvgpu.mbarrier.test.wait (nvgpu::MBarrierTestWaitOp) 

Checks if the nvgpu.mbarrier has completed its current phase.

Syntax:

operation ::= `nvgpu.mbarrier.test.wait` $barriers `[` $mbarId `]` `,` $token attr-dict `:` type($barriers) `,` type($token)

Checks whether the mbarrier object has completed the phase. It is a non-blocking instruction that tests for the completion of the phase.

Example:

  %isComplete = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>, !nvgpu.mbarrier.token

Operands: 

Operand | Description
barriers | mbarrier barrier type
token |
mbarId | index

Results: 

Result | Description
waitComplete | 1-bit signless integer

nvgpu.mbarrier.try_wait.parity (nvgpu::MBarrierTryWaitParityOp) 

Waits for the nvgpu.mbarrier to complete its current phase.

Syntax:

operation ::= `nvgpu.mbarrier.try_wait.parity` $barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)

Checks whether the mbarrier object has completed the phase. It is a potentially blocking instruction that tests for the completion of the phase. A suspended thread resumes execution when the specified phase completes, or earlier if a system-dependent time limit expires before the phase completes.

The $phaseParity operand specifies whether to wait for the even phase (0) or the odd phase (1).
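A common way to drive $phaseParity is a multi-stage pipeline in which each stage's mbarrier completes exactly one phase per trip around a ring of barriers. The sketch below assumes that usage pattern (a convention, not something the op requires) and that every barrier starts in the even phase 0; `wait_parity_schedule` is a hypothetical helper.

```python
def wait_parity_schedule(num_iterations: int, num_stages: int):
    """Return (barrier index, phaseParity) for each iteration of a pipelined
    loop that cycles through a ring of num_stages mbarriers.

    Each barrier completes one phase per use, starting in the even phase, so
    the parity to wait on flips every time the loop wraps around the ring.
    """
    return [(i % num_stages, (i // num_stages) % 2) for i in range(num_iterations)]
```

With three stages, iterations 0 to 2 wait on parity 0 of barriers 0, 1, 2, and iterations 3 to 5 wait on parity 1 of the same barriers.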

Example:

  nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phaseParity, %ticks : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

Operands: 

Operand | Description
barriers | mbarrier barrier type
phaseParity | 1-bit signless integer
ticks | index
mbarId | index

nvgpu.mma.sp.sync (nvgpu::MmaSparseSyncOp) 

Syntax:

operation ::= `nvgpu.mma.sp.sync` `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
              `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)

The nvgpu.mma.sp.sync operation performs a warp-distributed MMA operation where operand A is “structured sparse”. In this case, the matrixA operand represents the (warp-distributed) non-zero values of operand A, and the sparseMetadata operand provides their indices.

The sparsity storage format and distribution scheme are fully described in the PTX documentation. This operation is meant to follow the semantics described there: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

The way the indices are distributed among the threads in a warp is controlled by the optional sparsitySelector attribute, which is 0 by default. For more information, please consult the PTX documentation linked above.
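For f16/bf16 operands, mma.sp uses 2:4 structured sparsity: each group of four consecutive values along K keeps at most two non-zeros, together with a 2-bit position index for each kept value. The sketch below compresses a single row that way and packs the indices into an integer. The helper names and the low-bits-first packing order are assumptions for illustration only; the exact metadata layout and its distribution across threads (including the effect of sparsitySelector) are defined by the PTX documentation linked above.

```python
def compress_2to4(row):
    """Compress a row with 2:4 structured sparsity into (values, indices),
    keeping exactly two entries (value and 2-bit position) per group of 4."""
    if len(row) % 4:
        raise ValueError("row length must be a multiple of 4")
    values, indices = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        nz = [i for i, v in enumerate(group) if v != 0]
        if len(nz) > 2:
            raise ValueError(f"group starting at {g} has more than 2 non-zeros")
        # Pad with the lowest unused positions so every group stores two entries.
        for i in range(4):
            if len(nz) == 2:
                break
            if i not in nz:
                nz.append(i)
        nz.sort()
        values += [group[i] for i in nz]
        indices += nz
    return values, indices


def pack_metadata_bits(indices):
    """Pack 2-bit indices into one integer, first index in the lowest bits
    (an assumed ordering, for illustration)."""
    word = 0
    for k, idx in enumerate(indices):
        word |= idx << (2 * k)
    return word
```

The compressed values hold half of the original K extent, which is why matrixA carries only the non-zero half of operand A.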

Example (targeting the f16 16x8x32 mma.sp PTX instruction):

nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
  (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mmaShape | ::mlir::ArrayAttr | 64-bit integer array attribute
sparsitySelector | ::mlir::IntegerAttr | 32-bit signless integer attribute
tf32Enabled | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
matrixA | vector of any type values
matrixB | vector of any type values
matrixC | vector of any type values
sparseMetadata | fixed-length vector of 16-bit signless integer values of length 2

Results: 

Result | Description
res | vector of any type values

nvgpu.mma.sync (nvgpu::MmaSyncOp) 

Syntax:

operation ::= `nvgpu.mma.sync` `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
              `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)

The nvgpu.mma.sync op represents the warp-level matrix-multiply-and-accumulate (mma) operation that is compatible with nvvm.mma.sync. The operand and result vector sizes reflect each thread's ownership of its share of the warp-level mma operation shape. The mmaShape attribute holds the warp-level matrix-multiply shape.

The nvgpu.mma.sync op serves as an intermediate point between lowering from vector.contract to nvvm.mma.sync.

This operation is meant to follow the semantics described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma

Example:

%res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
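The per-thread vector sizes in the example above can be checked by splitting each warp-level tile evenly over the 32 threads of a warp. This is a hypothetical back-of-the-envelope helper that only counts elements; it does not model which thread owns which element.

```python
WARP_SIZE = 32


def per_thread_elements(rows: int, cols: int) -> int:
    """Elements of a rows x cols tile owned by each thread when a warp splits
    the tile evenly."""
    total = rows * cols
    if total % WARP_SIZE:
        raise ValueError("tile does not split evenly across a warp")
    return total // WARP_SIZE


def mma_sync_fragment_sizes(m: int, n: int, k: int):
    """Per-thread element counts for A (m x k), B (k x n) and C/D (m x n)."""
    return (per_thread_elements(m, k),
            per_thread_elements(k, n),
            per_thread_elements(m, n))
```

For mmaShape = [16, 8, 16] this gives (8, 4, 4), which matches vector<4x2xf16> (8 elements), vector<2x2xf16> (4), and vector<2x2xf32> (4) in the example.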

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mmaShape | ::mlir::ArrayAttr | 64-bit integer array attribute
tf32Enabled | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
matrixA | vector of any type values
matrixB | vector of any type values
matrixC | vector of any type values

Results: 

Result | Description
res | vector of any type values

nvgpu.tma.async.load (nvgpu::TmaAsyncLoadOp) 

TMA asynchronous load

Syntax:

operation ::= `nvgpu.tma.async.load` $tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]`
              `to` $dst
              (`multicast_mask` `=` $multicastMask^ )?
              (`,` `predicate` `=` $predicate^)?
              attr-dict `:` type($tensorMapDescriptor) `,` type($barriers)
              `->` type($dst)

The Op loads a tile memory region from global memory to shared memory using the Tensor Memory Accelerator (TMA).

$tensorMapDescriptor is a tensor map descriptor that holds information about the tile shape. The descriptor is created by nvgpu.tma.create.descriptor.

The Op uses an mbarrier-based completion mechanism on $barriers.

Traits: AttrSizedOperandSegments

Operands: 

Operand | Description
dst | memref of any type values
barriers | mbarrier barrier type
tensorMapDescriptor | TensorMap descriptor
coordinates | variadic of index
mbarId | index
multicastMask | 16-bit signless integer
predicate | 1-bit signless integer

nvgpu.tma.async.store (nvgpu::TmaAsyncStoreOp) 

TMA asynchronous store

Syntax:

operation ::= `nvgpu.tma.async.store` $src `to` $tensorMapDescriptor `[` $coordinates `]`
              (`,` `predicate` `=` $predicate^)?
              attr-dict `:` type($src)
              `->` type($tensorMapDescriptor)

The Op stores a tile memory region from shared memory to global memory using the Tensor Memory Accelerator (TMA).

$tensorMapDescriptor is a tensor map descriptor that holds information about the tile shape. The descriptor is created by nvgpu.tma.create.descriptor.

Traits: AttrSizedOperandSegments

Operands: 

Operand | Description
src | memref of any type values
tensorMapDescriptor | TensorMap descriptor
coordinates | variadic of index
predicate | 1-bit signless integer

nvgpu.tma.create.descriptor (nvgpu::TmaCreateDescriptorOp) 

TMA create descriptor

Syntax:

operation ::= `nvgpu.tma.create.descriptor` $tensor `box` `[` $boxDimensions `]` attr-dict `:` type($tensor) `->` type($tensorMap)

The Op creates a tensor map descriptor object representing a tiled memory region. To do so, it calls the CUDA Driver’s cuTensorMapEncodeTiled. The descriptor is used by the Tensor Memory Accelerator (TMA).

$tensor is the source tensor to be tiled.

$boxDimensions specifies the size of the tiled memory region in each dimension.

For more information, see: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
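To show what the box dimensions mean, this hypothetical helper lists the start coordinates of the boxes that tile a tensor, one coordinate per dimension in the order the tensor shape is given; the coordinate ordering expected by nvgpu.tma.async.load is not modeled. The limit of 256 elements per box dimension is taken from the cuTensorMapEncodeTiled documentation linked above.

```python
from itertools import product

MAX_BOX_DIM = 256  # cuTensorMapEncodeTiled requires 0 < boxDim[i] <= 256


def box_coordinates(tensor_shape, box_dims):
    """Start coordinate of every box needed to cover tensor_shape.

    Boxes at the upper edge may extend past the tensor; those elements are
    out of bounds and are filled according to the descriptor's OOB mode.
    """
    if len(tensor_shape) != len(box_dims):
        raise ValueError("tensor and box must have the same rank")
    for b in box_dims:
        if not 0 < b <= MAX_BOX_DIM:
            raise ValueError(f"box dimension {b} is outside (0, {MAX_BOX_DIM}]")
    starts = [range(0, size, box) for size, box in zip(tensor_shape, box_dims)]
    return list(product(*starts))
```

For a 128x100 tensor with 64x64 boxes this yields four boxes; the two starting at column 64 cover columns 100 to 127 out of bounds.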

Operands: 

Operand | Description
tensor | unranked.memref of any type values
boxDimensions | variadic of index

Results: 

Result | Description
tensorMap | TensorMap descriptor

nvgpu.tma.prefetch.descriptor (nvgpu::TmaPrefetchOp) 

Prefetch given nvgpu.tensormap.descriptor

Syntax:

operation ::= `nvgpu.tma.prefetch.descriptor` $tensorMapDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type($tensorMapDescriptor)

The Op prefetches the cache line containing the given $tensorMapDescriptor for subsequent use by the tma.async.load instruction.

Operands: 

Operand | Description
tensorMapDescriptor | TensorMap descriptor
predicate | 1-bit signless integer

nvgpu.warpgroup.generate.descriptor (nvgpu::WarpgroupGenerateDescriptorOp) 

Generate a warpgroup matrix descriptor

Syntax:

operation ::= `nvgpu.warpgroup.generate.descriptor` $tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap) `->` type($descriptor)

This Op builds an nvgpu.warpgroup.descriptor that is used by nvgpu.warpgroup.mma to perform a warpgroup-level matrix multiply and accumulate.

The descriptor specifies the properties of the matrix in shared memory that is a multiplicand in the matrix multiply and accumulate operation.

Operands: 

Operand | Description
tensor | memref of any type values
tensorMap | TensorMap descriptor

Results: 

Result | Description
descriptor | Warpgroup matrix descriptor type

nvgpu.warpgroup.mma (nvgpu::WarpgroupMmaOp) 

Syntax:

operation ::= `nvgpu.warpgroup.mma` $descriptorA`,` $descriptorB`,` $matrixC attr-dict
              `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)

The nvgpu.warpgroup.mma op performs the warpgroup-level (4 warps) matrix-multiply-and-accumulate (mma) operation and lowers to nvvm.wgmma.mma_async.

The descriptorA and descriptorB operands are wgmma matrix descriptors that describe the properties of the matrices in shared memory. The results reflect each thread's ownership of its share of the warpgroup-level mma operation shape. The shape is deduced from the descriptor types and the output vector.

The Op encapsulates multiple nvvm.wgmma.mma_async operations to complete the given shape. Because nvvm.wgmma.mma_async (like its corresponding PTX instruction) is asynchronous, this Op groups the nvvm.wgmma.mma_async operations, placing wgmma.fence.aligned before them and wgmma.commit.group.sync.aligned and wgmma.wait.group.sync.aligned after them.

Example:

  %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2: 
             !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
             !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
             !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
             !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
             -> 
             !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
             !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
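A rough way to estimate how many nvvm.wgmma.mma_async operations the Op expands to: each wgmma instruction produces a 64 x N tile and consumes 256 bits of K per row of A (16 elements for f16/bf16, 8 for tf32, 32 for 8-bit types). The sketch below is a hypothetical model under those assumptions, with the full N extent handled by each instruction; the actual lowering may order or group the instructions differently.

```python
WGMMA_M = 64  # each wgmma.mma_async instruction produces 64 rows of D


def wgmma_k(elem_bits: int) -> int:
    """K extent of one wgmma instruction: 256 bits of A per row."""
    return 256 // elem_bits


def wgmma_tiling(m: int, n: int, k: int, elem_bits: int):
    """Return (m_tiles, k_steps, instructions) for an m x n x k warpgroup GEMM,
    assuming each instruction covers 64 rows, the full N extent, and one K step."""
    kk = wgmma_k(elem_bits)
    if m % WGMMA_M or k % kk or n % 8 or not 8 <= n <= 256:
        raise ValueError("shape not handled by this sketch")
    m_tiles, k_steps = m // WGMMA_M, k // kk
    return m_tiles, k_steps, m_tiles * k_steps
```

For the example above (A: 128x64xf16, B: 64x128xf16) the sketch gives 2 M tiles, consistent with the two 64x128 accumulators, and 4 K steps, i.e. 8 wgmma instructions.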

Attributes: 

Attribute | MLIR Type | Description
waitGroup | ::mlir::IntegerAttr | 32-bit signless integer attribute
transposeA | ::mlir::UnitAttr | unit attribute
transposeB | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
descriptorA | Warpgroup matrix descriptor type
descriptorB | Warpgroup matrix descriptor type
matrixC |

Results: 

Result | Description
matrixD |

nvgpu.warpgroup.mma.init.accumulator (nvgpu::WarpgroupMmaInitAccumulatorOp) 

Initializes the accumulator matrix

Syntax:

operation ::= `nvgpu.warpgroup.mma.init.accumulator` attr-dict `->` type($matrixC)

This Op generates and initializes the accumulator matrix for nvgpu.warpgroup.mma op to perform matrix-multiply-and-accumulate.

Results: 

Result | Description
matrixC |

nvgpu.warpgroup.mma.store (nvgpu::WarpgroupMmaStoreOp) 

Syntax:

operation ::= `nvgpu.warpgroup.mma.store` $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)

The nvgpu.warpgroup.mma.store op stores the fragmented result in $matrixD to the given memref.

See the details of the register fragment layout for accumulator matrix D: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

Note that the op must be executed by a full warpgroup.
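Based on the PTX register fragment figure for the m64nNk16 accumulator D, the sketch below maps each (thread, register) pair of the 128-thread warpgroup to a (row, column) of a 64 x N result tile, with each thread holding N/2 values. The function names are hypothetical, and this reading of the PTX figure should be checked against the PTX ISA before it is relied on; larger accumulators repeat the 64 x N tile.

```python
def wgmma_d_coords(thread: int, reg: int):
    """(row, col) of accumulator register `reg` held by `thread` (0..127)
    in a 64 x N wgmma result tile."""
    warp, lane = divmod(thread, 32)
    group, tig = divmod(lane, 4)  # groupID and thread index within the group
    row = warp * 16 + group + 8 * ((reg // 2) % 2)
    col = (reg // 4) * 8 + tig * 2 + (reg % 2)
    return row, col


def check_layout(n: int) -> bool:
    """True if every element of the 64 x n tile is owned by exactly one
    (thread, register) pair."""
    owners = [wgmma_d_coords(t, r) for t in range(128) for r in range(n // 2)]
    expected = {(i, j) for i in range(64) for j in range(n)}
    return len(owners) == len(set(owners)) and set(owners) == expected
```

Each warp covers 16 consecutive rows, and within a warp the registers come in groups of four covering an 8-column slice, so `check_layout` holds for any N that is a multiple of 8.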

Operands: 

Operand | Description
matrixD |
dstMemref | memref of any type values

Attributes 

TensorMapInterleaveKindAttr 

Tensor map interleave layout type

Syntax:

#nvgpu.interleave<
  ::mlir::nvgpu::TensorMapInterleaveKind   # value
>

Enum cases:

  • none (INTERLEAVE_NONE)
  • interleave_16b (INTERLEAVE_16B)
  • interleave_32b (INTERLEAVE_32B)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::nvgpu::TensorMapInterleaveKind | an enum of type TensorMapInterleaveKind

TensorMapL2PromoKindAttr 

Tensor map L2 promotion type

Syntax:

#nvgpu.l2promo<
  ::mlir::nvgpu::TensorMapL2PromoKind   # value
>

Enum cases:

  • none (L2PROMO_NONE)
  • l2promo_64b (L2PROMO_64B)
  • l2promo_128b (L2PROMO_128B)
  • l2promo_256b (L2PROMO_256B)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::nvgpu::TensorMapL2PromoKind | an enum of type TensorMapL2PromoKind

TensorMapOOBKindAttr 

Tensor map out-of-bounds fill type

Syntax:

#nvgpu.oob<
  ::mlir::nvgpu::TensorMapOOBKind   # value
>

Enum cases:

  • zero (OOB_ZERO)
  • nan (OOB_NAN)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::nvgpu::TensorMapOOBKind | an enum of type TensorMapOOBKind

TensorMapSwizzleKindAttr 

Tensor map swizzling mode of shared memory banks

Syntax:

#nvgpu.swizzle<
  ::mlir::nvgpu::TensorMapSwizzleKind   # value
>

Enum cases:

  • none (SWIZZLE_NONE)
  • swizzle_32b (SWIZZLE_32B)
  • swizzle_64b (SWIZZLE_64B)
  • swizzle_128b (SWIZZLE_128B)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::nvgpu::TensorMapSwizzleKind | an enum of type TensorMapSwizzleKind

Types 

DeviceAsyncTokenType 

device async token type

Syntax: !nvgpu.device.async.token

nvgpu.device.async.token is a type returned by an asynchronous operation that runs on the GPU (device). It is used to establish an SSA-based link between the async operation (e.g. DeviceAsyncCopyOp) and operations that group or synchronize the async operations (e.g. DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp).

MBarrierGroupType 

mbarrier barrier type

Syntax:

!nvgpu.mbarrier.group<
  Attribute,   # memorySpace
  unsigned   # num_barriers
>

This is the type for one or more mbarrier objects in shared memory that are used to synchronize a variable number of threads.

If num_barriers is not set, the number of mbarrier objects is 1.

An mbarrier object is 64 bits wide with 8-byte alignment. The mbarrier object can be initialized and invalidated.

See the PTX ISA for more details.

Parameters: 

Parameter | C++ type | Description
memorySpace | Attribute |
num_barriers | unsigned |

MBarrierTokenType 

Syntax: !nvgpu.mbarrier.token

TensorMapDescriptorType 

TensorMap descriptor

Syntax:

!nvgpu.tensormap.descriptor<
  MemRefType,   # tensor
  ::mlir::nvgpu::TensorMapSwizzleKind,   # swizzle
  ::mlir::nvgpu::TensorMapL2PromoKind,   # l2promo
  ::mlir::nvgpu::TensorMapOOBKind,   # oob
  ::mlir::nvgpu::TensorMapInterleaveKind   # interleave
>

nvgpu.tensormap.descriptor is a type that represents a TMA descriptor. It is a 128-byte object that resides either in constant space or in a kernel parameter.

Parameters: 

Parameter | C++ type | Description
tensor | MemRefType |
swizzle | ::mlir::nvgpu::TensorMapSwizzleKind | an enum of type TensorMapSwizzleKind
l2promo | ::mlir::nvgpu::TensorMapL2PromoKind | an enum of type TensorMapL2PromoKind
oob | ::mlir::nvgpu::TensorMapOOBKind | an enum of type TensorMapOOBKind
interleave | ::mlir::nvgpu::TensorMapInterleaveKind | an enum of type TensorMapInterleaveKind

WarpgroupAccumulatorType 

Syntax:

!nvgpu.warpgroup.accumulator<
  VectorType   # fragmented
>

This type represents the result matrix obtained from nvgpu.warpgroup.mma. The $fragmented type signifies the distributed or fragmented result vector that is collectively owned by all the threads in the warpgroup that executed nvgpu.warpgroup.mma. See the details of the register fragment layout for accumulator matrix D: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

Parameters: 

Parameter | C++ type | Description
fragmented | VectorType |

WarpgroupMatrixDescriptorType 

Warpgroup matrix descriptor type

Syntax:

!nvgpu.warpgroup.descriptor<
  MemRefType   # tensor
>

The descriptor specifies the properties of the matrix in shared memory that is a multiplicand in the matrix multiply and accumulate operation.

The descriptor is a 64-bit value contained in a register with the following layout:

+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
|   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
|  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
| BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
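The bit layout above can be packed and unpacked with shifts and masks. This sketch assumes the PTX ISA encoding matrix-descriptor-encode(x) = (x & 0x3FFFF) >> 4 for the address and byte-offset fields, and assumes the swizzle field uses 0 = none, 1 = 128B, 2 = 64B, 3 = 32B. The function names are hypothetical, and the reserved fields are left at zero.

```python
def encode(x: int) -> int:
    """PTX matrix-descriptor-encode: keep 18 address bits, drop the low 4."""
    return (x & 0x3FFFF) >> 4


SWIZZLE_MODES = {"none": 0, "128b": 1, "64b": 2, "32b": 3}  # assumed field values


def pack_wgmma_descriptor(base_addr: int, leading_byte_offset: int,
                          stride_byte_offset: int, base_offset: int = 0,
                          swizzle: str = "none") -> int:
    if not 0 <= base_offset < 8:
        raise ValueError("base offset is a 3-bit field")
    desc = encode(base_addr)                   # bits 0-13
    desc |= encode(leading_byte_offset) << 16  # bits 16-29
    desc |= encode(stride_byte_offset) << 32   # bits 32-45
    desc |= base_offset << 49                  # bits 49-51
    desc |= SWIZZLE_MODES[swizzle] << 62       # bits 62-63
    return desc


def unpack_wgmma_descriptor(desc: int) -> dict:
    def field(lo: int, width: int) -> int:
        return (desc >> lo) & ((1 << width) - 1)

    return {"base": field(0, 14), "leading": field(16, 14),
            "stride": field(32, 14), "offset": field(49, 3),
            "swizzle": field(62, 2)}
```

Because of the >> 4 in the encoding, the address and byte-offset fields are stored in 16-byte units, so a shared-memory address of 0x400 is stored as 0x40.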

See the PTX ISA for more details.

Parameters: 

Parameter | C++ type | Description
tensor | MemRefType |