
'gpu' Dialect

Note: this dialect is more likely to change than others in the near future; use with caution.

This dialect provides middle-level abstractions for launching GPU kernels following a programming model similar to that of CUDA or OpenCL. It provides abstractions for kernel invocations (and may eventually provide those for device management) that are not present at the lower level (e.g., as LLVM IR intrinsics for GPUs). Its goal is to abstract away device- and driver-specific manipulations to launch a GPU kernel and provide a simple path towards GPU execution from MLIR. It may be targeted, for example, by DSLs using MLIR. The dialect uses gpu as its canonical prefix.

Memory attribution 

Memory buffers are defined at the function level, either in “gpu.launch” or in “gpu.func” ops. This encoding makes it clear where the memory belongs and makes the lifetime of the memory visible. The memory is only accessible while the kernel is launched or the function is being invoked. The latter is stricter than actual GPU implementations, but declaring static memory at the function level is simply a convenience. It is also always possible to pass pointers to the workgroup memory into other functions, provided they expect the correct memory space.

The buffers are considered live throughout the execution of the GPU function body. The absence of memory attribution syntax means that the function does not require special buffers. Rationale: although the underlying models declare memory buffers at the module level, we chose to do it at the function level to provide some structuring for the lifetime of those buffers. This avoids the incentive to use the buffers for communicating between different kernels or launches of the same kernel, which should be done through function arguments instead. We also chose not to use an alloca-style approach that would require more complex lifetime analysis, following the principles of MLIR that promote structure and representing analysis results in the IR.
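
For illustration, a minimal sketch of a kernel using both attributions (the buffer shapes and the address spaces 3 and 5 are assumptions, chosen to mirror the gpu.func example later in this document):

gpu.func @attributed(%arg0: index)
    workgroup(%wg: memref<32xf32, 3>)
    private(%pv: memref<1xf32, 5>)
    kernel {
  // %wg is shared by all work items of the workgroup; %pv is per work item.
  %c0 = arith.constant 0 : index
  %v = memref.load %wg[%c0] : memref<32xf32, 3>
  memref.store %v, %pv[%c0] : memref<1xf32, 5>
  gpu.return
}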

GPU Compilation 

Deprecation notice 

The --gpu-to-(cubin|hsaco) passes will be deprecated in a future release.

Compilation overview 

The compilation process in the GPU dialect has two main stages: GPU module serialization and offloading operations translation. Together these stages can produce GPU binaries and the necessary code to execute them.

An example of how the compilation workflow looks:

mlir-opt example.mlir                   \
  --pass-pipeline="builtin.module(      \
    gpu-kernel-outlining,               \ # Outline gpu.launch body to a kernel.
    nvvm-attach-target{chip=sm_90 O=3}, \ # Attach an NVVM target to a gpu.module op.
    gpu.module(convert-gpu-to-nvvm),    \ # Convert GPU to NVVM.
    gpu-to-llvm,                        \ # Convert GPU to LLVM.
    gpu-module-to-binary                \ # Serialize GPU modules to binaries.
  )" -o example-nvvm.mlir
mlir-translate example-nvvm.mlir        \
  --mlir-to-llvmir                      \ # Obtain the translated LLVM IR.
  -o example.ll

Default NVVM Compilation Pipeline: gpu-lower-to-nvvm-pipeline 

The gpu-lower-to-nvvm-pipeline compilation pipeline serves as the default way to compile for the NVVM target within MLIR. This pipeline operates by lowering the primary dialects (arith, memref, scf, vector, gpu, and nvgpu) to the NVVM target. It begins by lowering GPU code region(s) to the specified NVVM compilation target and subsequently handles the host code.

This pipeline requires explicitly parallel IR and does not perform GPU parallelization itself. To enable parallelism, the necessary transformations must be applied before using this pipeline.

It is designed to provide a generic solution for NVVM targets, generating NVVM and LLVM dialect code compatible with mlir-cpu-runner or the execution engine.

Example: 

Here’s a snippet illustrating the use of primary dialects, including arith, within GPU code execution:

func.func @main() {
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    gpu.launch 
        blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1) 
        threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) { 
        gpu.printf "Hello from %d\n" %6 : index
        gpu.terminator
    }
    return
}

The gpu-lower-to-nvvm-pipeline compiles this input code to NVVM format as shown below. It provides customization options such as specifying the SM capability, PTX version, and optimization level. Once compiled, the resulting IR is ready for execution using mlir-cpu-runner. Alternatively, it can be translated into LLVM IR, expanding its utility within the system.

mlir-opt example.mlir -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"

Module serialization 

Attributes implementing the GPU Target Attribute Interface handle the serialization process and are called Target attributes. These attributes can be attached to GPU Modules indicating the serialization scheme to compile the module into a binary string.

The gpu-module-to-binary pass searches for all nested GPU modules and serializes the module using the target attributes attached to the module, producing a binary with an object for every target.

Example:

// Input:
gpu.module @kernels [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_60">] {
  ...
}
// mlir-opt --gpu-module-to-binary:
gpu.binary @kernels [
  #gpu.object<#nvvm.target<chip = "sm_90">, "sm_90 cubin">,
  #gpu.object<#nvvm.target<chip = "sm_60">, "sm_60 cubin">
]

Offloading LLVM translation 

Attributes implementing the GPU Offloading LLVM Translation Attribute Interface handle the translation of GPU binaries and kernel launches into LLVM instructions and are called Offloading attributes. These attributes are attached to GPU binary operations.

During the LLVM translation process, GPU binaries are translated using the scheme provided by the Offloading attribute, turning the GPU binary into LLVM instructions. Kernel launches, meanwhile, are translated by searching for the appropriate binary and invoking the procedure provided by that binary’s Offloading attribute for translating kernel launches into LLVM instructions.

Example:

// Input:
// Binary with multiple objects but selecting the second one for embedding.
gpu.binary @binary <#gpu.select_object<#rocdl.target<chip = "gfx90a">>> [
    #gpu.object<#nvvm.target, "NVPTX">,
    #gpu.object<#rocdl.target<chip = "gfx90a">, "AMDGPU">
  ]
llvm.func @foo() {
  ...
  // Launching a kernel inside the binary.
  gpu.launch_func @binary::@func blocks in (%0, %0, %0)
                                 threads in (%0, %0, %0) : i64
                                 dynamic_shared_memory_size %2
                                 args(%1 : i32, %1 : i32)
  ...
}
// mlir-translate --mlir-to-llvmir:
@binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
@binary_func_kernel_name = private unnamed_addr constant [5 x i8] c"func\00", align 1
...
define void @foo() {
  ...
  %module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
  %kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name)
  call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel
  ...
  call void @mgpuModuleUnload(ptr %module)
  ...
}
...

The binary operation 

From a semantic point of view, GPU binaries allow the implementation of many concepts, from simple object files to fat binaries. By default, the binary operation uses the #gpu.select_object offloading attribute; this attribute embeds a single object in the binary as a global string, see the attribute docs for more information.
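
For instance, a binary that relies on the default handler could look as follows (a sketch reusing the object form from the serialization example above; it behaves like an explicit <#gpu.select_object<0>>):

gpu.binary @kernels [#gpu.object<#nvvm.target<chip = "sm_90">, "sm_90 cubin">]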

Operations 


gpu.all_reduce (gpu::AllReduceOp) 

Reduces values among the work items of a workgroup.

Syntax:

operation ::= `gpu.all_reduce` custom<AllReduceOperation>($op) $value
              (`uniform` $uniform^)? $body attr-dict
              `:` functional-type(operands, results)

The all_reduce op reduces the value of every work item across a local workgroup. The result is equal for all work items of a workgroup.

For example, both

%1 = gpu.all_reduce add %0 {} : (f32) -> (f32)
%2 = gpu.all_reduce %0 {
^bb(%lhs : f32, %rhs : f32):
  %sum = arith.addf %lhs, %rhs : f32
  "gpu.yield"(%sum) : (f32) -> ()
} : (f32) -> (f32)

compute the sum of each work item’s %0 value. The first version specifies the accumulation as an operation, whereas the second version specifies the accumulation as a code region. The reduction operation must be one of:

  • Integer types: add, mul, minui, minsi, maxui, maxsi, and, or, xor
  • Floating point types: add, mul, minnumf, maxnumf, minimumf, maximumf

If the uniform flag is set, either none or all work items of a workgroup need to execute this op in convergence.

Traits: IsolatedFromAbove, SameOperandsAndResultType

Interfaces: InferTypeOpInterface

Attributes: 

  • op (::mlir::gpu::AllReduceOperationAttr): built-in reduction operations supported by gpu.all_reduce. Enum cases: add (ADD), mul (MUL), minui (MINUI), minsi (MINSI), minnumf (MINNUMF), maxui (MAXUI), maxsi (MAXSI), maxnumf (MAXNUMF), and (AND), or (OR), xor (XOR), minimumf (MINIMUMF), maximumf (MAXIMUMF).
  • uniform (::mlir::UnitAttr): unit attribute.

Operands: 

  • value: Integer or Float

Results: 

  • result: Integer or Float

gpu.alloc (gpu::AllocOp) 

GPU memory allocation operation.

Syntax:

operation ::= `gpu.alloc` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) (` ` `host_shared` $hostShared^)? ` `
              `(` $dynamicSizes `)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref)

The gpu.alloc operation allocates a region of memory on the GPU. It is similar to the memref.alloc op, but supports asynchronous GPU execution.

The op does not execute before all async dependencies have finished executing.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it also returns a !gpu.async.token.

If the host_shared keyword is present, the memory will be allocated in a memory accessible both on host and on device.

Example:

%memref, %token = gpu.alloc async [%dep] host_shared (%width) : memref<64x?xf32, 1>

Traits: AttrSizedOperandSegments

Interfaces: GPU_AsyncOpInterface

Attributes: 

  • hostShared (::mlir::UnitAttr): unit attribute.

Operands: 

  • asyncDependencies: variadic of async token type
  • dynamicSizes: variadic of index
  • symbolOperands: variadic of index

Results: 

  • memref: memref of any type values
  • asyncToken: async token type

gpu.barrier (gpu::BarrierOp) 

Synchronizes all work items of a workgroup.

Syntax:

operation ::= `gpu.barrier` attr-dict

The “barrier” op synchronizes all work items of a workgroup. It is used to coordinate communication between the work items of the workgroup.

gpu.barrier

waits until all work items in the workgroup have reached this point and all memory accesses made by these work items prior to the op are visible to all work items in the workgroup. Data hazards between work items accessing the same memory can be avoided by synchronizing work items in-between these accesses.

Either none or all work items of a workgroup need to execute this op in convergence.
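
A minimal sketch of the resulting store/barrier/load pattern (the workgroup buffer %buf, its shape, and the neighbor indexing are assumptions for illustration; bounds handling is omitted):

%tIdX = gpu.thread_id x
%val = arith.constant 1.0 : f32
// Each work item writes its own slot of the workgroup buffer.
memref.store %val, %buf[%tIdX] : memref<32xf32, 3>
// Make all prior writes visible to every work item of the workgroup.
gpu.barrier
// Now it is safe to read a slot written by another work item.
%c1 = arith.constant 1 : index
%next = arith.addi %tIdX, %c1 : index
%other = memref.load %buf[%next] : memref<32xf32, 3>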

gpu.binary (gpu::BinaryOp) 

An Op for storing serialized GPU binary objects.

Syntax:

operation ::= `gpu.binary` $sym_name custom<OffloadingHandler>($offloadingHandler) attr-dict $objects

GPU binaries provide a semantic mechanism for storing GPU objects, e.g. the result of compiling a GPU module to an object file.

This operation has 3 arguments:

  • The name of the binary.
  • An optional attribute implementing the offloading LLVM translation interface.
  • An array of GPU object attributes.

During translation, the offloading attribute will be called for translating GPU binary and launch_func operations. The default offloading handler is #gpu.select_object; this handler selects the first object from the array and embeds it as a string.

Examples:

  // Selects the first object.
  gpu.binary @myobject [#gpu.object<...>, #gpu.object<...>]
  // Uses the `#foo.my_handler` for handling the binary during translation.
  gpu.binary @myobject <#foo.my_handler> [#gpu.object<...>, #gpu.object<...>]
  // Selects the object with the `#rocdl.target` target attribute.
  gpu.binary @myobject <#gpu.select_object<#rocdl.target>> [#gpu.object<...>, #gpu.object<#rocdl.target, ...>]

Interfaces: Symbol

Attributes: 

  • sym_name (::mlir::StringAttr): string attribute.
  • offloadingHandler (::mlir::Attribute): any attribute with the `OffloadingTranslationAttrTrait` trait.
  • objects (::mlir::ArrayAttr): an array of GPU object attributes with at least 1 element.

gpu.block_dim (gpu::BlockDimOp) 

Syntax:

operation ::= `gpu.block_dim` $dimension attr-dict

Returns the number of threads in the thread block (aka the block size) along the x, y, or z dimension.

Example:

%bDimX = gpu.block_dim x

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.block_id (gpu::BlockIdOp) 

Syntax:

operation ::= `gpu.block_id` $dimension attr-dict

Returns the block id, i.e. the index of the current block within the grid along the x, y, or z dimension.

Example:

%bIdY = gpu.block_id y

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.cluster_dim (gpu::ClusterDimOp) 

Syntax:

operation ::= `gpu.cluster_dim` $dimension attr-dict

Returns the number of thread blocks in the cluster along the x, y, or z dimension.

Example:

%cDimX = gpu.cluster_dim x

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.cluster_id (gpu::ClusterIdOp) 

Syntax:

operation ::= `gpu.cluster_id` $dimension attr-dict

Returns the cluster id, i.e. the index of the current cluster within the grid along the x, y, or z dimension.

Example:

%cIdY = gpu.cluster_id y

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.create_2to4_spmat (gpu::Create2To4SpMatOp) 

Create sparse matrix with 2:4 sparsity operation

Syntax:

operation ::= `gpu.create_2to4_spmat` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              `{` $pruneFlag `}` $rows `,` $cols `,` $memref attr-dict `:` type($memref)

The gpu.create_2to4_spmat operation initializes a sparse matrix in dense format with 2:4 sparsity. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_2to4_spmat async [%dep] {PRUNE_AND_CHECK} %rows, %cols, %mem: memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Attributes: 

  • pruneFlag (::mlir::gpu::Prune2To4SpMatFlagAttr): pruning strategy for 2:4 sparse matrix. Enum cases: NONE, PRUNE_ONLY, PRUNE_AND_CHECK.

Operands: 

  • asyncDependencies: variadic of async token type
  • rows: index
  • cols: index
  • memref: memref of any type values

Results: 

  • spMat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_bsr (gpu::CreateBsrOp) 

Create sparse matrix in BSR format operation

Syntax:

operation ::= `gpu.create_bsr` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $brows `,` $bcols `,` $bnnz `,` $rBlockSize `,` $cBlockSize `,`
              $bRowPos `,` $bColIdxs `,` $values attr-dict
              `:` type($bRowPos) `,` type($bColIdxs) `,` type($values)

The gpu.create_bsr operation initializes a sparse matrix in BSR format with the given sizes for the matrix and blocks from the given position, index, and values buffers. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor.

The BSR format is similar to CSR, where the column indices represent two-dimensional blocks instead of a single matrix entry. Note that this operation (currently) only supports storage with square blocks, i.e., rBlockSize == cBlockSize.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_bsr async [%dep]
   %brows, %bcols, %bnnz, %rBlockSize, %cBlockSize,
   %bRowPos, %bColIdxs, %values : memref<?xindex>, memref<?xindex>, memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • brows: index
  • bcols: index
  • bnnz: index
  • rBlockSize: index
  • cBlockSize: index
  • bRowPos: memref of any type values
  • bColIdxs: memref of any type values
  • values: memref of any type values

Results: 

  • spmat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_coo_aos (gpu::CreateCooAoSOp) 

Create sparse matrix in COO format operation (AoS)

Syntax:

operation ::= `gpu.create_coo_aos` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $rows `,` $cols `,` $nnz `,` $idxs `,` $values attr-dict
              `:` type($idxs) `,` type($values)

The gpu.create_coo_aos operation initializes a sparse matrix in COO format with the given sizes from the given index and values buffers. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor. Unlike the default gpu.create_coo operation, this operation builds the COO format from a single index buffer in AoS format (note that this feature has been deprecated in cuSparse 11.2).

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_coo_aos async [%dep] %rows, %cols, %nnz, %idxs,
    %values : memref<?xindex>, memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • rows: index
  • cols: index
  • nnz: index
  • idxs: memref of any type values
  • values: memref of any type values

Results: 

  • spmat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_coo (gpu::CreateCooOp) 

Create sparse matrix in COO format operation

Syntax:

operation ::= `gpu.create_coo` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $rows `,` $cols `,` $nnz `,` $rowIdxs `,` $colIdxs `,` $values attr-dict
              `:` type($rowIdxs) `,` type($colIdxs) `,` type($values)

The gpu.create_coo operation initializes a sparse matrix in COO format with the given sizes from the given index and values buffers. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor. Note that this operation builds the COO in SoA format.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_coo async [%dep] %rows, %cols, %nnz, %rowIdx,
    %colIdx, %values : memref<?xindex>, memref<?xindex>, memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • rows: index
  • cols: index
  • nnz: index
  • rowIdxs: memref of any type values
  • colIdxs: memref of any type values
  • values: memref of any type values

Results: 

  • spmat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_csc (gpu::CreateCscOp) 

Create sparse matrix in CSC format operation

Syntax:

operation ::= `gpu.create_csc` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $rows `,` $cols `,` $nnz `,` $colPos `,` $rowIdxs `,` $values attr-dict
              `:` type($colPos) `,` type($rowIdxs) `,` type($values)

The gpu.create_csc operation initializes a sparse matrix in CSC format with the given sizes from the given position, index, and values buffers. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor.

The CSC format has exactly the same memory layout as its transpose in CSR format (and vice versa).

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_csc async [%dep] %rows, %cols, %nnz, %colPos,
    %rowIdx, %values : memref<?xindex>, memref<?xindex>, memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • rows: index
  • cols: index
  • nnz: index
  • colPos: memref of any type values
  • rowIdxs: memref of any type values
  • values: memref of any type values

Results: 

  • spmat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_csr (gpu::CreateCsrOp) 

Create sparse matrix in CSR format operation

Syntax:

operation ::= `gpu.create_csr` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $rows `,` $cols `,` $nnz `,` $rowPos `,` $colIdxs `,` $values attr-dict
              `:` type($rowPos) `,` type($colIdxs) `,` type($values)

The gpu.create_csr operation initializes a sparse matrix in CSR format with the given sizes from the given position, index, and values buffers. The buffers must already be copied from the host to the device prior to using this operation. The operation returns a handle to the sparse matrix descriptor.

The CSR format has exactly the same memory layout as its transpose in CSC format (and vice versa).

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%spmat, %token = gpu.create_csr async [%dep] %rows, %cols, %nnz, %rowPos,
    %colIdx, %values : memref<?xindex>, memref<?xindex>, memref<?xf64>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • rows: index
  • cols: index
  • nnz: index
  • rowPos: memref of any type values
  • colIdxs: memref of any type values
  • values: memref of any type values

Results: 

  • spmat: sparse matrix handle type
  • asyncToken: async token type

gpu.create_dn_tensor (gpu::CreateDnTensorOp) 

Create dense tensor operation

Syntax:

operation ::= `gpu.create_dn_tensor` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $memref `,` $dims attr-dict `:` type($dims) `into` type($memref)

The gpu.create_dn_tensor operation initializes a dense tensor from the given values buffer and sizes. The buffer must already be copied from the host to the device prior to using this operation. The operation returns a handle to the dense tensor descriptor.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%dmat, %token = gpu.create_dn_tensor async [%dep] %mem, %dims : index, index into memref<?xf64>

Traits: AttrSizedOperandSegments

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • memref: memref of any type values
  • dims: variadic of index

Results: 

  • dnTensor: dense tensor handle type
  • asyncToken: async token type

gpu.dealloc (gpu::DeallocOp) 

GPU memory deallocation operation

Syntax:

operation ::= `gpu.dealloc` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $memref attr-dict `:` type($memref)

The gpu.dealloc operation frees the region of memory referenced by a memref which was originally created by the gpu.alloc operation. It is similar to the memref.dealloc op, but supports asynchronous GPU execution.

The op does not execute before all async dependencies have finished executing.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token.

Example:

%token = gpu.dealloc async [%dep] %memref : memref<8x64xf32, 1>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • memref: memref of any type values

Results: 

  • asyncToken: async token type

gpu.destroy_dn_tensor (gpu::DestroyDnTensorOp) 

Destroy dense tensor operation

Syntax:

operation ::= `gpu.destroy_dn_tensor` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $dnTensor attr-dict

The gpu.destroy_dn_tensor operation releases all resources of a dense tensor represented by a handle that was previously created by a gpu.create_dn_tensor operation.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%token = gpu.destroy_dn_tensor async [%dep] %dnTensor

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • dnTensor: dense tensor handle type

Results: 

  • asyncToken: async token type

gpu.destroy_sp_mat (gpu::DestroySpMatOp) 

Destroy sparse matrix operation

Syntax:

operation ::= `gpu.destroy_sp_mat` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $spmat attr-dict

The gpu.destroy_sp_mat operation releases all resources of a sparse matrix represented by a handle that was previously created by one of the sparse matrix creation operations.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%token = gpu.destroy_sp_mat async [%dep] %spmat

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • spmat: sparse matrix handle type

Results: 

  • asyncToken: async token type

gpu.dynamic_shared_memory (gpu::DynamicSharedMemoryOp) 

Get the memref for dynamic shared memory

Syntax:

operation ::= `gpu.dynamic_shared_memory` attr-dict `:` type($resultMemref)

This operation provides a memref pointer to the start of dynamic shared memory, often referred to as workgroup memory. It’s important to note that this dynamic shared memory needs to be allocated at kernel launch. One can conveniently utilize the dynamic_shared_memory_size parameter of gpu.launch for this purpose.

Examples:

%0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
%1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
                        to memref<32x64xf32, #gpu.address_space<workgroup>>
%2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
                        to memref<32x64xf32, #gpu.address_space<workgroup>>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results: 

  • resultMemref: 1D memref of 8-bit signless integer values

gpu.func (gpu::GPUFuncOp) 

Function executable on a GPU

Defines a function that can be executed on a GPU. This supports memory attribution and its body has a particular execution model.

GPU functions are either kernels (as indicated by the kernel attribute) or regular functions. The former can be launched from the host side, while the latter are device side only.

The memory attribution defines SSA values that correspond to memory buffers allocated in the memory hierarchy of the GPU (see below).

The operation has one attached region that corresponds to the body of the function. The region arguments consist of the function arguments without modification, followed by buffers defined in memory annotations. The body of a GPU function, when launched, is executed by multiple work items. There are no guarantees on the order in which work items execute, or on the connection between them. In particular, work items are not necessarily executed in lock-step. Synchronization ops such as “gpu.barrier” should be used to coordinate work items. Declarations of GPU functions, i.e. not having the body region, are not supported.

A function may optionally be annotated with the block and/or grid sizes that will be used when it is launched using the gpu.known_block_size and gpu.known_grid_size attributes, respectively. If set, these attributes must be arrays of three 32-bit integers giving the x, y, and z launch dimensions. Launching a kernel that has these annotations, or that calls a function with these annotations, using a block size or grid size other than what is specified is undefined behavior.
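
For example, a kernel known to always launch with 128x1x1 blocks in a 1x1x1 grid might be annotated as follows (a sketch; it assumes the un-prefixed attribute spelling accepted on gpu.func):

gpu.func @annotated_kernel(%arg0: memref<128xf32>)
    kernel attributes {known_block_size = array<i32: 128, 1, 1>,
                       known_grid_size = array<i32: 1, 1, 1>} {
  gpu.return
}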

Syntax:

op ::= `gpu.func` symbol-ref-id `(` argument-list `)` (`->`
function-result-list)?
       memory-attribution `kernel`? function-attributes? region

memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
                       (`private` `(` ssa-id-and-type-list `)`)?

Example:

gpu.func @foo(%arg0: index)
    workgroup(%workgroup: memref<32xf32, 3>)
    private(%private: memref<1xf32, 5>)
    kernel
    attributes {qux: "quux"} {
  gpu.return
}

The generic form illustrates the concept:

"gpu.func"(%arg: index) {sym_name: "foo", kernel, qux: "quux"} ({
^bb0(%arg0: index, %workgroup: memref<32xf32, 3>,
     %private: memref<1xf32, 5>):
  "gpu.return"() : () -> ()
}) : (index) -> ()

Note the non-default memory spaces used in memref types in memory attribution.

Traits: AutomaticAllocationScope, HasParent<GPUModuleOp>, IsolatedFromAbove

Interfaces: CallableOpInterface, FunctionOpInterface, Symbol

Attributes: 

  • function_type (::mlir::TypeAttr): type attribute of function type.
  • arg_attrs (::mlir::ArrayAttr): array of dictionary attributes.
  • res_attrs (::mlir::ArrayAttr): array of dictionary attributes.
  • workgroup_attrib_attrs (::mlir::ArrayAttr): array of dictionary attributes.
  • private_attrib_attrs (::mlir::ArrayAttr): array of dictionary attributes.

gpu.module (gpu::GPUModuleOp) 

A top level compilation unit containing code to be run on a GPU.

A GPU module contains code that is intended to be run on a GPU. A host device can launch this code through a gpu.launch_func, which creates a fully qualified symbol from the gpu.module’s symbol and a gpu.func symbol contained in the gpu.module.

The module’s top-level scope is modeled by a single region with a single block. GPU modules are required to have a name that is used for symbol resolution by the gpu.launch_func operation.

Using an op with a region to define a GPU module enables “embedding” GPU modules with SIMT execution models in other dialects in a clean manner, and allows filtering code regions so that passes run only on code that is (or is not) intended to run on the separate device.

Modules can contain zero or more target attributes. These attributes encode how to transform modules into binary strings and are used by the gpu-module-to-binary pass to transform modules into GPU binaries.

Modules can contain an optional OffloadingTranslationAttr attribute. This attribute will be used during the gpu-module-to-binary pass to specify the OffloadingTranslationAttr used when creating the gpu.binary operation.

gpu.module @symbol_name {
  gpu.func {}
    ...
  gpu.module_end
}
// Module with offloading handler and target attributes.
gpu.module @symbol_name2 <#gpu.select_object<1>> [
    #nvvm.target,
    #rocdl.target<chip = "gfx90a">] {
  gpu.func {}
    ...
  gpu.module_end
}

Traits: HasDefaultDLTIDataLayout, IsolatedFromAbove, SingleBlockImplicitTerminator<ModuleEndOp>, SingleBlock, SymbolTable

Interfaces: DataLayoutOpInterface, Symbol

Attributes: 

  • targets (::mlir::ArrayAttr): array of GPU target attributes with at least 1 element.
  • offloadingHandler (::mlir::Attribute): any attribute with the `OffloadingTranslationAttrTrait` trait.

gpu.global_id (gpu::GlobalIdOp) 

Syntax:

operation ::= `gpu.global_id` $dimension attr-dict

Returns the unique global workitem/thread id, i.e., the unique index of the current workitem/thread within all workgroups / grid along the x, y, or z dimension.

Example:

%gidX = gpu.global_id x

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.grid_dim (gpu::GridDimOp) 

Syntax:

operation ::= `gpu.grid_dim` $dimension attr-dict

Returns the number of thread blocks in the grid along the x, y, or z dimension.

Example:

%gDimZ = gpu.grid_dim z

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results: 

  • «unnamed»: index

gpu.host_register (gpu::HostRegisterOp) 

Registers a memref for access from device.

Syntax:

operation ::= `gpu.host_register` $value attr-dict `:` type($value)

This op maps the provided host buffer into the device address space.

This operation may not be supported in every environment; there is not yet a way to check at runtime whether this feature is supported.

Writes from the host are guaranteed to be visible to device kernels that are launched afterwards. Writes from the device are guaranteed to be visible on the host after synchronizing with the device kernel completion.
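
For example, a host buffer can be registered after casting it to the unranked memref type the op expects (a sketch):

%memref = memref.alloc() : memref<64xf32>
%unranked = memref.cast %memref : memref<64xf32> to memref<*xf32>
gpu.host_register %unranked : memref<*xf32>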

Operands: 

  • value: unranked memref of any type values

gpu.host_unregister (gpu::HostUnregisterOp) 

Unregisters a memref for access from device.

Syntax:

operation ::= `gpu.host_unregister` $value attr-dict `:` type($value)

This op unmaps the provided host buffer from the device address space.

This operation may not be supported in every environment; there is not yet a way to check at runtime whether this feature is supported.

Operands: 

  • value: unranked memref of any type values

gpu.lane_id (gpu::LaneIdOp) 

Syntax:

operation ::= `gpu.lane_id` attr-dict

Returns the lane id within the subgroup (warp/wave).

Example:

%laneId = gpu.lane_id

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results: 

  • result: index

gpu.launch_func (gpu::LaunchFuncOp) 

Launches a function as a GPU kernel

Syntax:

operation ::= `gpu.launch_func` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              (`<` $asyncObject^ `:` type($asyncObject) `>`)?
              $kernel
              ( `clusters` `in` ` ` `(` $clusterSizeX^ `,` $clusterSizeY `,` $clusterSizeZ `)` )?
              `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
              `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
              custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
              (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
              custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict

Launch a kernel function on the specified grid of thread blocks. gpu.launch operations are lowered to gpu.launch_func operations by outlining the kernel body into a function in a dedicated module, which reflects the separate compilation process. The kernel function is required to have the gpu.kernel attribute. The module containing the kernel function is required to be a gpu.module. And finally, the module containing the kernel module (which thus cannot be the top-level module) is required to have the gpu.container_module attribute. The gpu.launch_func operation has a symbol attribute named kernel to identify the fully specified kernel function to launch (both the gpu.module and func).

The gpu.launch_func supports async dependencies: the kernel does not start executing until the ops producing those async dependencies have completed.

By default, the host implicitly blocks until kernel execution has completed. If the async keyword is present, the host does not block; instead, a !gpu.async.token is returned. Other async GPU ops can take this token as a dependency.

The operation requires at least the grid and block sizes along the x,y,z dimensions as arguments. When a lower-dimensional kernel is required, unused sizes must be explicitly set to 1.

The remaining operands are optional. The first optional operand corresponds to the amount of dynamic shared memory a kernel’s workgroup should be allocated; when this operand is not present, a zero size is assumed.

The remaining operands, if present, are passed as arguments to the kernel function.

The gpu.launch_func also supports kernel launching with clusters if supported by the target architecture. The cluster size can be set by clusterSizeX, clusterSizeY, and clusterSizeZ arguments. When these arguments are present, the Op launches a kernel that clusters the given thread blocks. This feature is exclusive to certain architectures.

Example:

module attributes {gpu.container_module} {

  // This module creates a separate compilation unit for the GPU compiler.
  gpu.module @kernels {
    func.func @kernel_1(%arg0 : f32, %arg1 : memref<?xf32, 1>)
        attributes { nvvm.kernel = true } {

      // Operations that produce block/thread IDs and dimensions are
      // injected when outlining the `gpu.launch` body to a function called
      // by `gpu.launch_func`.
      %tIdX = gpu.thread_id x
      %tIdY = gpu.thread_id y
      %tIdZ = gpu.thread_id z

      %bDimX = gpu.block_dim x
      %bDimY = gpu.block_dim y
      %bDimZ = gpu.block_dim z

      %bIdX = gpu.block_id x
      %bIdY = gpu.block_id y
      %bIdZ = gpu.block_id z

      %gDimX = gpu.grid_dim x
      %gDimY = gpu.grid_dim y
      %gDimZ = gpu.grid_dim z

      // (Optional) Cluster IDs and dimensions, only on supported architectures.
      %cIdX = gpu.cluster_id x
      %cIdY = gpu.cluster_id y
      %cIdZ = gpu.cluster_id z

      %cDimX = gpu.cluster_dim x
      %cDimY = gpu.cluster_dim y
      %cDimZ = gpu.cluster_dim z

      "some_op"(%bx, %tx) : (index, index) -> ()
      %42 = load %arg1[%bx] : memref<?xf32, 1>
    }
  }

  %t0 = gpu.wait async
  gpu.launch_func
      async                           // (Optional) Don't block host, return token.
      [%t0]                           // (Optional) Execute only after %t0 has completed.
      @kernels::@kernel_1             // Kernel function.
      clusters in (%cst, %cst, %cst)  // (Optional) Cluster size, only on supported architectures.
      blocks in (%cst, %cst, %cst)    // Grid size.
      threads in (%cst, %cst, %cst)   // Block size.
      dynamic_shared_memory_size %s   // (Optional) Amount of dynamic shared
                                      // memory to allocate for a workgroup.
      args(%arg0 : f32,               // (Optional) Kernel arguments.
           %arg1 : memref<?xf32, 1>)
}

Traits: AttrSizedOperandSegments

Interfaces: GPU_AsyncOpInterface

Attributes: 

  • kernel (::mlir::SymbolRefAttr): symbol reference attribute.

Operands: 

  • asyncDependencies: variadic of async token type
  • gridSizeX: index or 32-bit signless integer or 64-bit signless integer
  • gridSizeY: index or 32-bit signless integer or 64-bit signless integer
  • gridSizeZ: index or 32-bit signless integer or 64-bit signless integer
  • blockSizeX: index or 32-bit signless integer or 64-bit signless integer
  • blockSizeY: index or 32-bit signless integer or 64-bit signless integer
  • blockSizeZ: index or 32-bit signless integer or 64-bit signless integer
  • clusterSizeX: index or 32-bit signless integer or 64-bit signless integer
  • clusterSizeY: index or 32-bit signless integer or 64-bit signless integer
  • clusterSizeZ: index or 32-bit signless integer or 64-bit signless integer
  • dynamicSharedMemorySize: 32-bit signless integer
  • kernelOperands: variadic of any type
  • asyncObject: any type

Results: 

  • asyncToken: async token type

gpu.launch (gpu::LaunchOp) 

GPU kernel launch operation

Launch a kernel on the specified grid of thread blocks. The body of the kernel is defined by the single region that this operation contains. The operation takes an optional list of async dependencies followed by six operands and an optional operand.

The async keyword indicates the kernel should be launched asynchronously; the operation returns a new !gpu.async.token when the keyword is specified. The kernel launched does not start executing until the ops producing its async dependencies (optional operands) have completed.

The first three operands (following any async dependencies) are grid sizes along the x,y,z dimensions and the following three are block sizes along the x,y,z dimensions. When a lower-dimensional kernel is required, unused sizes must be explicitly set to 1. The last operand is optional and corresponds to the amount of dynamic shared memory a kernel’s workgroup should be allocated; when this operand is not present, a zero size is assumed.

The body region has at least twelve arguments, or eighteen if cluster dimensions are present, grouped as follows:

  • three optional arguments that contain cluster identifiers along x,y,z dimensions;
  • three arguments that contain block identifiers along x,y,z dimensions;
  • three arguments that contain thread identifiers along x,y,z dimensions;
  • operands of the gpu.launch operation as is (i.e. the operands for grid and block sizes);
  • a variadic number of Workgroup memory attributions;
  • a variadic number of Private memory attributions.

Syntax:

operation ::= `gpu.launch` (`async` (`[` ssa-id-list `]`)? )?
                         ( `clusters` `(` ssa-id-list `)` `in` ssa-reassignment )?
                         `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
                         `threads` `(` ssa-id-list `)` `in` ssa-reassignment
                         (dynamic_shared_memory_size ssa-use)?
                         memory-attribution
                         region attr-dict?
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
                       (`private` `(` ssa-id-and-type-list `)`)?

Example:

gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
           threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5) {
  // Block and thread identifiers, as well as block/grid sizes are
  // immediately usable inside body region.
  "some_op"(%bx, %tx) : (index, index) -> ()
  // Assuming %val1 is defined outside the gpu.launch region.
  %42 = load %val1[%bx] : memref<?xf32, 1>
}

// Generic syntax explains how the pretty syntax maps to the IR structure.
"gpu.launch"(%cst, %cst, %c1,  // Grid sizes.
             %cst, %c1, %c1)   // Block sizes.

    {/*attributes*/}
    // All sizes and identifiers have "index" size.
    : (index, index, index, index, index, index) -> () {
// The operation passes block and thread identifiers, followed by grid and
// block sizes.
^bb0(%bx : index, %by : index, %bz : index,
     %tx : index, %ty : index, %tz : index,
     %num_bx : index, %num_by : index, %num_bz : index,
     %num_tx : index, %num_ty : index, %num_tz : index)
  "some_op"(%bx, %tx) : (index, index) -> ()
  %3 = "memref.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
}

// Launch with memory attributions.
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
           threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
           workgroup(%workgroup: memref<32xf32, 3>)
           private(%private: memref<1xf32, 5>) {
  // Block and thread identifiers, as well as block/grid sizes are
  // immediately usable inside body region.
  "some_op"(%bx, %tx) : (index, index) -> ()
  // Assuming %val1 is defined outside the gpu.launch region.
  %42 = load %workgroup[%bx] : memref<32xf32, 3>
}

// Launch with clusters.
gpu.launch clusters(%cx, %cy, %cz) in (%sz_cx = %0, %sz_cy = %1, %sz_cz = %2)
           blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5)
           threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8)
{
  // Cluster, block and thread identifiers, as well as cluster/block/grid 
  // sizes are immediately usable inside body region.
  "some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
}

Rationale: using operation/block arguments gives analyses a clear way of understanding that a value has additional semantics (e.g., we will need to know what value corresponds to threadIdx.x for coalescing). We can recover these properties by analyzing the operations producing values, but it is easier just to have that information by construction.

Traits: AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects

Interfaces: GPU_AsyncOpInterface, InferIntRangeInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • gridSizeX: index
  • gridSizeY: index
  • gridSizeZ: index
  • blockSizeX: index
  • blockSizeY: index
  • blockSizeZ: index
  • clusterSizeX: index
  • clusterSizeY: index
  • clusterSizeZ: index
  • dynamicSharedMemorySize: 32-bit signless integer

Results: 

  • asyncToken: async token type

gpu.memcpy (gpu::MemcpyOp) 

GPU memcpy operation

Syntax:

operation ::= `gpu.memcpy` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $dst`,` $src `:` type($dst)`,` type($src) attr-dict

The gpu.memcpy operation copies the content of one memref to another.

The op does not execute before all async dependencies have finished executing.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token.

Example:

%token = gpu.memcpy async [%dep] %dst, %src : memref<?xf32, 1>, memref<?xf32>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • dst: memref of any type values
  • src: memref of any type values

Results: 

  • asyncToken: async token type

gpu.memset (gpu::MemsetOp) 

GPU memset operation

Syntax:

operation ::= `gpu.memset` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $dst`,` $value `:` type($dst)`,` type($value) attr-dict

The gpu.memset operation sets the content of memref to a scalar value.

The op does not execute before all async dependencies have finished executing.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token.

Example:

%token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • dst: memref of any type values
  • value: any type

Results: 

  • asyncToken: async token type

gpu.module_end (gpu::ModuleEndOp) 

A pseudo op that marks the end of a gpu.module.

Syntax:

operation ::= `gpu.module_end` attr-dict

This op terminates the only block inside the only region of a gpu.module.

Traits: HasParent<GPUModuleOp>, Terminator

gpu.num_subgroups (gpu::NumSubgroupsOp) 

Syntax:

operation ::= `gpu.num_subgroups` attr-dict `:` type($result)

Returns the number of subgroups within a workgroup.

Example:

%numSg = gpu.num_subgroups : index

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results: 

  • result: index

gpu.printf (gpu::PrintfOp) 

Device-side printf, as in CUDA or OpenCL, for debugging

Syntax:

operation ::= `gpu.printf` $format attr-dict ($args^ `:` type($args))?

gpu.printf takes a literal format string format and an arbitrary number of scalar arguments that should be printed.

The format string is a C-style printf string, subject to any restrictions imposed by one’s target platform.
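
Example (a sketch; the exact format specifiers available depend on the target platform):

%tIdX = gpu.thread_id x
%val = arith.constant 4.2e+01 : f32
gpu.printf "Hello from %d, value %f\n" %tIdX, %val : index, f32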

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Attributes: 

  • format (::mlir::StringAttr): string attribute.

Operands: 

  • args: variadic of integer or index or floating-point

gpu.return (gpu::ReturnOp) 

Terminator for GPU functions.

Syntax:

operation ::= `gpu.return` attr-dict ($operands^ `:` type($operands))?

A terminator operation for regions that appear in the body of gpu.func functions. The operands to the gpu.return are the result values returned by an invocation of the gpu.func.

Traits: AlwaysSpeculatableImplTrait, HasParent<GPUFuncOp>, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

  • operands: variadic of any type

gpu.sddmm_buffer_size (gpu::SDDMMBufferSizeOp) 

Precompute buffersize for SDDMM operation

Syntax:

operation ::= `gpu.sddmm_buffer_size` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict `into` $computeType

The gpu.sddmm_buffer_size operation returns the buffer size required to perform the SDDMM operation on the given sparse and dense matrices. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SDDMM.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%buffersz, %token = gpu.sddmm_buffer_size async [%dep] %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC into f32

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Interfaces: GPU_AsyncOpInterface

Attributes: 

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands: 

  • asyncDependencies: variadic of async token type
  • dnmatA: dense tensor handle type
  • dnmatB: dense tensor handle type
  • spmatC: sparse matrix handle type

Results: 

  • bufferSz: index
  • asyncToken: async token type

gpu.sddmm (gpu::SDDMMOp) 

SDDMM operation

Syntax:

operation ::= `gpu.sddmm` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) `into` $computeType

The gpu.sddmm operation performs the SDDMM operation on the given sparse and dense matrices, and buffer. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SDDMM. The buffer must have been allocated on the device.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%token = gpu.sddmm async [%dep] %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer into f32

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Interfaces: GPU_AsyncOpInterface

Attributes: 

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands: 

  • asyncDependencies: variadic of async token type
  • dnmatA: dense tensor handle type
  • dnmatB: dense tensor handle type
  • spmatC: sparse matrix handle type
  • buffer: memref of any type values

Results: 

  • asyncToken: async token type

gpu.set_csr_pointers (gpu::SetCsrPointersOp) 

Set CSR pointers operation

Syntax:

operation ::= `gpu.set_csr_pointers` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmat `,` $positions `,` $coordinates `,` $values attr-dict
              `:` type($positions) `,` type($coordinates) `,` type($values)

The gpu.set_csr_pointers operation assigns the given position, coordinate, and value buffers, which must reside on the device, directly to the given sparse matrix descriptor in CSR format.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%token = gpu.set_csr_pointers async [%dep] %spmat, %positions, %coordinates, %values
      : memref<?xf32>, memref<?xindex>, memref<?xindex>

Interfaces: GPU_AsyncOpInterface

Operands: 

  • asyncDependencies: variadic of async token type
  • spmat: sparse matrix handle type
  • positions: memref of any type values
  • coordinates: memref of any type values
  • values: memref of any type values

Results: 

  • asyncToken: async token type

gpu.set_default_device (gpu::SetDefaultDeviceOp) 

Set default GPU for operations after this by index

Syntax:

operation ::= `gpu.set_default_device` attr-dict $devIndex

Operation that sets the current default GPU, using a zero-based index into the set of GPUs on the system. The default GPU setting may be thread-local.
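
Example (a sketch selecting the second GPU on the system):

%c1 = arith.constant 1 : i32
gpu.set_default_device %c1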

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Operands:

  • devIndex: 32-bit signless integer

gpu.shuffle (gpu::ShuffleOp) 

Shuffles values within a subgroup.

Syntax:

operation ::= `gpu.shuffle` $mode $value `,` $offset `,` $width attr-dict `:` type($value)

The “shuffle” op moves values across lanes (a.k.a., invocations, work items) within the same subgroup. The width argument specifies the number of lanes that participate in the shuffle, and must be uniform across all lanes. Further, the first width lanes of the subgroup must be active.

The interpretation of the offset argument depends on the selected mode.

Returns the shuffleResult and true if the current lane id is smaller than width, and an unspecified value and false otherwise.

xor example:

%1, %2 = gpu.shuffle xor %0, %offset, %width : f32

For lane k, returns the value %0 from lane k ^ offset. Every lane trades value with exactly one other lane.

down example:

%cst1 = arith.constant 1 : i32
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32

For lane k, returns the value from lane (k + 1) % width.

up example:

%cst1 = arith.constant 1 : i32
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32

For lane k, returns the value from lane (k - 1) % width.

idx example:

%cst0 = arith.constant 0 : i32
%7, %8 = gpu.shuffle idx %0, %cst0, %width : f32

Broadcasts the value from lane 0 to all lanes.
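
As a usage sketch, the xor mode can implement a subgroup-wide sum as a butterfly reduction. The snippet below assumes a 32-lane subgroup and an f32 input %val:

%width = arith.constant 32 : i32
%c16 = arith.constant 16 : i32
%v16, %p16 = gpu.shuffle xor %val, %c16, %width : f32
%s16 = arith.addf %val, %v16 : f32
%c8 = arith.constant 8 : i32
%v8, %p8 = gpu.shuffle xor %s16, %c8, %width : f32
%s8 = arith.addf %s16, %v8 : f32
// Continue with offsets 4, 2, and 1; every lane then holds the full sum.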

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

  • mode (::mlir::gpu::ShuffleModeAttr): indexing modes supported by gpu.shuffle. Enum cases: xor (XOR), up (UP), down (DOWN), idx (IDX).

Operands:

  • value: i32, i64, f32 or f64
  • offset: 32-bit signless integer
  • width: 32-bit signless integer

Results:

  • shuffleResult: i32, i64, f32 or f64
  • valid: 1-bit signless integer

gpu.spgemm_copy (gpu::SpGEMMCopyOp) 

SpGEMM copy operation

Syntax:

operation ::= `gpu.spgemm_copy` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc attr-dict `:` $computeType

The gpu.spgemm_copy operation copies the sparse matrix result of a SpGEMM computation.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc : f32

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands:

  • asyncDependencies: variadic of async token type
  • desc: SpGEMM operation handle type
  • spmatA: sparse matrix handle type
  • spmatB: sparse matrix handle type
  • spmatC: sparse matrix handle type

Results:

  • asyncToken: async token type

gpu.spgemm_create_descr (gpu::SpGEMMCreateDescrOp) 

SpGEMM Create Descr operation

Syntax:

operation ::= `gpu.spgemm_create_descr` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              attr-dict

The gpu.spgemm_create_descr creates a descriptor for the SpGEMM operation. The descriptor describes the SpGEMM operation and stores the internal data throughout the computation. It needs to be passed as an argument to spgemm_* operations.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%desc, %token = gpu.spgemm_create_descr async [%dep]
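
The descriptor is then threaded through the remaining spgemm_* ops. A sketch of the overall sequence, assuming sparse matrix handles %spmatA, %spmatB, %spmatC, an index zero %c0, and an empty staging buffer %alloc:

%desc, %t0 = gpu.spgemm_create_descr async [%dep]
%sz, %t1 = gpu.spgemm_work_estimation_or_compute async [%t0]{WORK_ESTIMATION}
             %spmatA, %spmatB, %spmatC, %desc, %c0, %alloc : f32 into memref<0xi8>
// Allocate buffers of the reported sizes, run the COMPUTE phase, then:
%t2 = gpu.spgemm_copy async [%t1] %spmatA, %spmatB, %spmatC, %desc : f32
%t3 = gpu.spgemm_destroy_descr async [%t2] %desc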

Interfaces: GPU_AsyncOpInterface

Operands:

  • asyncDependencies: variadic of async token type

Results:

  • desc: SpGEMM operation handle type
  • asyncToken: async token type

gpu.spgemm_destroy_descr (gpu::SpGEMMDestroyDescrOp) 

SpGEMM Destroy Descr operation

Syntax:

operation ::= `gpu.spgemm_destroy_descr` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $desc attr-dict

The gpu.spgemm_destroy_descr destroys the SpGEMM operation descriptor.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%token = gpu.spgemm_destroy_descr async [%dep] %desc

Interfaces: GPU_AsyncOpInterface

Operands:

  • asyncDependencies: variadic of async token type
  • desc: SpGEMM operation handle type

Results:

  • asyncToken: async token type

gpu.spgemm_work_estimation_or_compute (gpu::SpGEMMWorkEstimationOrComputeOp) 

SpGEMM work estimation operation

Syntax:

operation ::= `gpu.spgemm_work_estimation_or_compute` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              `{` $kind `}` $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc `,` $bufferSz `,` $buffer  attr-dict `:` $computeType `into` type($buffer)

The gpu.spgemm_work_estimation_or_compute op is used to call either cusparseSpGEMM_workEstimation or cusparseSpGEMM_compute. Each of these calls serves both to determine the buffer size and to perform the actual computation. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SpGEMM. The buffer must have been allocated on the device.

C’ = alpha * op(A) * op(B) + beta * C

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep] {COMPUTE}
                      %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
                      %spmatC, %spgemmDesc, %c0, %alloc : f32 into
                      memref<0xi8>

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.
  • kind (::mlir::gpu::SpGEMMWorkEstimationOrComputeKindAttr): choose whether spgemm_work_estimation_or_compute does work estimation or compute. Enum cases: WORK_ESTIMATION, COMPUTE.

Operands:

  • asyncDependencies: variadic of async token type
  • desc: SpGEMM operation handle type
  • spmatA: sparse matrix handle type
  • spmatB: sparse matrix handle type
  • spmatC: sparse matrix handle type
  • bufferSz: index
  • buffer: memref of any type values

Results:

  • bufferSzNew: index
  • asyncToken: async token type

gpu.spmm_buffer_size (gpu::SpMMBufferSizeOp) 

Precompute buffersize for SpMM operation

Syntax:

operation ::= `gpu.spmm_buffer_size` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict `:` type($bufferSzs) `into` $computeType

The gpu.spmm_buffer_size operation returns the buffer size required to perform the SpMM operation on the given sparse and dense matrices. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SpMM.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Example:

%bufferszs, %token = gpu.spmm_buffer_size async [%dep] %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC : i64 into f32

Traits: AttrSizedResultSegments

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands:

  • asyncDependencies: variadic of async token type
  • spmatA: sparse matrix handle type
  • dnmatB: dense tensor handle type
  • dnmatC: dense tensor handle type

Results:

  • bufferSzs: variadic of index
  • asyncToken: async token type

gpu.spmm (gpu::SpMMOp) 

SpMM operation

Syntax:

operation ::= `gpu.spmm` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffers attr-dict `:` type($buffers) `into` $computeType

The gpu.spmm operation performs the SpMM operation on the given sparse and dense matrices, and the buffer. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SpMM. The buffer must have been allocated on the device.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Example:

%token = gpu.spmm async [%dep] %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffers : memref<?xf32> into f32

Traits: AttrSizedOperandSegments

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • modeB (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands:

  • asyncDependencies: variadic of async token type
  • spmatA: sparse matrix handle type
  • dnmatB: dense tensor handle type
  • dnmatC: dense tensor handle type
  • buffers: variadic of memref of any type values

Results:

  • asyncToken: async token type

gpu.spmv_buffer_size (gpu::SpMVBufferSizeOp) 

Precompute buffersize for SpMV operation

Syntax:

operation ::= `gpu.spmv_buffer_size` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict  `into` $computeType

The gpu.spmv_buffer_size operation returns the buffer size required to perform the SpMV operation on the given sparse matrix and dense vectors. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SpMV.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Example:

%buffersz, %token = gpu.spmv_buffer_size async [%dep] %spmatA{TRANSPOSE}, %dnX, %dnY into f32
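
The returned size is typically used to allocate the device buffer that gpu.spmv consumes. A possible continuation of the example above, with an assumed i8 buffer element type:

%buffer, %t1 = gpu.alloc async [%token] (%buffersz) : memref<?xi8>
%t2 = gpu.spmv async [%t1] %spmatA{TRANSPOSE}, %dnX, %dnY, %buffer : memref<?xi8> into f32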

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands:

  • asyncDependencies: variadic of async token type
  • spmatA: sparse matrix handle type
  • dnX: dense tensor handle type
  • dnY: dense tensor handle type

Results:

  • bufferSz: index
  • asyncToken: async token type

gpu.spmv (gpu::SpMVOp) 

SpMV operation

Syntax:

operation ::= `gpu.spmv` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) `into` $computeType

The gpu.spmv operation performs the SpMV operation on the given sparse matrix, dense vectors, and buffer. The operation expects handles returned by previous sparse operations to construct an environment and the operands for SpMV. The buffer must have been allocated on the device.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

The matrix arguments can also be associated with one of the following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value is NON_TRANSPOSE.

Example:

%token = gpu.spmv async [%dep] %spmatA{TRANSPOSE}, %dnX, %dnY, %buffer : memref<?xf64> into bf16

Interfaces: GPU_AsyncOpInterface

Attributes:

  • modeA (::mlir::gpu::TransposeModeAttr): transpose mode of sparse matrix supported by sparse tensor ops. Enum cases: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE.
  • computeType (::mlir::TypeAttr): any type attribute.

Operands:

  • asyncDependencies: variadic of async token type
  • spmatA: sparse matrix handle type
  • dnX: dense tensor handle type
  • dnY: dense tensor handle type
  • buffer: memref of any type values

Results:

  • asyncToken: async token type

gpu.spmat_get_size (gpu::SpMatGetSizeOp) 

SpMat get size operation

Syntax:

operation ::= `gpu.spmat_get_size` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
              $spmat attr-dict

The gpu.spmat_get_size operation retrieves the number of rows, number of columns, and number of non-zero elements of a sparse matrix.

If the async keyword is present, the op is executed asynchronously (i.e. it does not block until the execution has finished on the device). In that case, it returns a !gpu.async.token in addition to the environment.

Example:

%rows, %cols, %nnz, %token = gpu.spmat_get_size async [%dep] %spmatC

Interfaces: GPU_AsyncOpInterface

Operands:

  • asyncDependencies: variadic of async token type
  • spmat: sparse matrix handle type

Results:

  • rows: index
  • cols: index
  • nnz: index
  • asyncToken: async token type

gpu.subgroup_id (gpu::SubgroupIdOp) 

Syntax:

operation ::= `gpu.subgroup_id` attr-dict `:` type($result)

Returns the subgroup id, i.e., the index of the current subgroup within the workgroup.

Example:

%sgId = gpu.subgroup_id : index

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results:

  • result: index

gpu.subgroup_mma_compute (gpu::SubgroupMmaComputeOp) 

GPU warp synchronous matrix multiply accumulate

Syntax:

operation ::= `gpu.subgroup_mma_compute` $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)

The gpu.subgroup_mma_compute operation performs a matrix-multiply accumulate (mma) operation using all the threads in a subgroup.

This operation takes three !gpu.mma_matrix values as arguments: they hold the A, B, and C operands for the mma operation. The operation performed is represented as C += A * B. The op returns a !gpu.mma_matrix which contains the result of the operation held by all threads in a subgroup. If present, a_transpose or b_transpose signifies that the respective operand was loaded in a transposed manner. The transpose operands are required to map to the correct underlying intrinsics, but they currently do not seem to affect correctness even if they are absent, given that the operands were loaded correctly using the transpose attribute in the gpu.subgroup_mma_load_matrix op.

For integer types, the A and B matrices carry their signedness with their types. The accumulator type is expected to be signless and imply a signed integer with a greater width than the other two operands.

This op is meant to be used along with gpu.subgroup_mma_store_matrix and gpu.subgroup_mma_load_matrix ops.

Example:

%D = gpu.subgroup_mma_compute %A, %B, %C :
  !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
  -> !gpu.mma_matrix<16x16xf16, "COp">
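
Combined with the load and store ops, a full subgroup-level tile multiply can be sketched as follows (source/destination memrefs and indices are assumed):

%A = gpu.subgroup_mma_load_matrix %srcA[%i, %j] {leadDimension = 32 : index}
     : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
%B = gpu.subgroup_mma_load_matrix %srcB[%i, %j] {leadDimension = 32 : index}
     : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
%C = gpu.subgroup_mma_load_matrix %srcC[%i, %j] {leadDimension = 32 : index}
     : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
%D = gpu.subgroup_mma_compute %A, %B, %C
     : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
     -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %D, %dst[%i, %j] {leadDimension = 32 : index}
     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>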

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

  • a_transpose (::mlir::UnitAttr): unit attribute.
  • b_transpose (::mlir::UnitAttr): unit attribute.

Operands:

  • opA: gpu.mma_matrix of 8-bit signed integer or 8-bit unsigned integer or 16-bit float or 32-bit float values
  • opB: gpu.mma_matrix of 8-bit signed integer or 8-bit unsigned integer or 16-bit float or 32-bit float values
  • opC: gpu.mma_matrix of 32-bit signless integer or 16-bit float or 32-bit float values

Results:

  • res: MMAMatrix type

gpu.subgroup_mma_constant_matrix (gpu::SubgroupMmaConstantMatrixOp) 

GPU warp synchronous constant matrix

Syntax:

operation ::= `gpu.subgroup_mma_constant_matrix` $value attr-dict `:` type($res)

The gpu.subgroup_mma_constant_matrix creates a !gpu.mma_matrix with constant elements.

The operation takes a scalar input and returns a !gpu.mma_matrix where each element is equal to the operand constant. The destination mma_matrix type must have an element type equal to the constant type. Since the layout of !gpu.mma_matrix is opaque, this only supports setting all the elements to the same value.

This op is meant to be used along with gpu.subgroup_mma_compute.

Example:

 %0 = gpu.subgroup_mma_constant_matrix %a :
   !gpu.mma_matrix<16x16xf16, "AOp">
 %1 = gpu.subgroup_mma_constant_matrix %b :
   !gpu.mma_matrix<16x16xf32, "COp">

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

  • value: 8-bit signed integer or 8-bit unsigned integer or 32-bit signless integer or 16-bit float or 32-bit float

Results:

  • res: MMAMatrix type

gpu.subgroup_mma_elementwise (gpu::SubgroupMmaElementwiseOp) 

GPU warp elementwise operation on a matrix

Syntax:

operation ::= `gpu.subgroup_mma_elementwise` $opType $args attr-dict `:` functional-type($args, $res)

The gpu.subgroup_mma_elementwise op takes !gpu.mma_matrix inputs and computes a new !gpu.mma_matrix by applying an elementwise operation to each element.

Since the operation is elementwise and the matrix type must match, the matrix elements are processed independently of the matrix layout.

This op is meant to be used along with gpu.subgroup_mma_compute.

Example:

 %0 = gpu.subgroup_mma_elementwise addf %A, %B :
  (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
  -> !gpu.mma_matrix<16x16xf16, "COp">

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

  • opType (::mlir::gpu::MMAElementwiseOpAttr): elementwise operation to apply to mma matrix. Enum cases: addf (ADDF), mulf (MULF), subf (SUBF), maxf (MAXF), minf (MINF), divf (DIVF), addi (ADDI), muli (MULI), subi (SUBI), divs (DIVS), divu (DIVU), negatef (NEGATEF), negates (NEGATES), extf (EXTF).

Operands:

  • args: variadic of MMAMatrix type

Results:

  • res: MMAMatrix type

gpu.subgroup_mma_load_matrix (gpu::SubgroupMmaLoadMatrixOp) 

GPU warp synchronous matrix load

Syntax:

operation ::= `gpu.subgroup_mma_load_matrix` $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)

The gpu.subgroup_mma_load_matrix operation loads a matrix collectively using all the threads in a subgroup.

This operation takes a memref as its first operand: it is the source matrix from which data is to be loaded. The op returns a !gpu.mma_matrix. The source memref can be in global memory or shared memory. The load address is determined using indices. The matrix being loaded into is the result. The leadDimension attribute specifies the leading dimension size of the source matrix which eventually allows the lowering to determine the size of each row. If the transpose attribute is present then the op does a transposed load.

For integer types, the resulting !gpu.mma_matrix type needs to specify the signedness of the data if the matrix type is an A or B operand for gpu.subgroup_mma_compute.

This op is often meant to be used along with gpu.subgroup_mma_store_matrix and gpu.subgroup_mma_compute.

Example:

 %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
      : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Attributes:

  • leadDimension (::mlir::IntegerAttr): index attribute.
  • transpose (::mlir::UnitAttr): unit attribute.

Operands:

  • srcMemref: memref of 8-bit signless integer, 32-bit signless integer, 16-bit float, or 32-bit float values, or of rank-1 vectors of those element types
  • indices: variadic of index

Results:

  • res: MMAMatrix type

gpu.subgroup_mma_store_matrix (gpu::SubgroupMmaStoreMatrixOp) 

GPU warp synchronous matrix store

Syntax:

operation ::= `gpu.subgroup_mma_store_matrix` $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)

The gpu.subgroup_mma_store_matrix operation stores a matrix collectively using all the threads in a subgroup.

This operation takes a !gpu.mma_matrix and a memref as operands. !gpu.mma_matrix is the source value containing the data to be stored into the destination memref which can be in global or shared memory. The store address is determined using the indices provided. The leadDimension attribute specifies the leading dimension of the destination matrix. If the transpose attribute is present then the op does a transposed store.

This op is often meant to be used along with gpu.subgroup_mma_load_matrix and gpu.subgroup_mma_compute.

Example:

gpu.subgroup_mma_store_matrix %D, %sg[%i, %j] {leadDimension = 32 : index}
                : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Attributes:

  • leadDimension (::mlir::IntegerAttr): index attribute.
  • transpose (::mlir::UnitAttr): unit attribute.

Operands:

  • src: gpu.mma_matrix of 8-bit signed integer or 8-bit unsigned integer or 32-bit signless integer or 16-bit float or 32-bit float values
  • dstMemref: memref of 8-bit signless integer, 32-bit signless integer, 16-bit float, or 32-bit float values, or of rank-1 vectors of those element types
  • indices: variadic of index

gpu.subgroup_reduce (gpu::SubgroupReduceOp) 

Reduce values among subgroup.

Syntax:

operation ::= `gpu.subgroup_reduce` custom<AllReduceOperation>($op) $value
              (`uniform` $uniform^)? attr-dict
              `:` functional-type(operands, results)

The subgroup_reduce op reduces the value of every lane (work item) across a subgroup. The result is equal for all lanes.

When the reduced value is of a vector type, each vector element is reduced independently. Only 1-d vector types are allowed.

Example:

%1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)

If the uniform flag is set, either none or all lanes of a subgroup need to execute this op in convergence. The reduction operation must be one of:

  • Integer types: add, mul, minui, minsi, maxui, maxsi, and, or, xor
  • Floating point types: add, mul, minnumf, maxnumf, minimumf, maximumf
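
For instance, a uniform floating-point maximum over a fully converged subgroup can be written as:

%3 = gpu.subgroup_reduce maxnumf %a uniform : (f32) -> (f32)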

Traits: SameOperandsAndResultType

Interfaces: InferTypeOpInterface

Attributes:

  • op (::mlir::gpu::AllReduceOperationAttr): built-in reduction operations supported by gpu.allreduce. Enum cases: add (ADD), mul (MUL), minui (MINUI), minsi (MINSI), minnumf (MINNUMF), maxui (MAXUI), maxsi (MAXSI), maxnumf (MAXNUMF), and (AND), or (OR), xor (XOR), minimumf (MINIMUMF), maximumf (MAXIMUMF).
  • uniform (::mlir::UnitAttr): unit attribute.

Operands:

  • value: Integer or Float, or a rank-1 vector of Integer or Float values

Results:

  • result: Integer or Float, or a rank-1 vector of Integer or Float values

gpu.subgroup_size (gpu::SubgroupSizeOp) 

Syntax:

operation ::= `gpu.subgroup_size` attr-dict `:` type($result)

Returns the number of threads within a subgroup.

Example:

%sgSz = gpu.subgroup_size : index

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results:

  • result: index

gpu.terminator (gpu::TerminatorOp) 

Terminator for GPU launch regions.

Syntax:

operation ::= `gpu.terminator` attr-dict

A terminator operation for regions that appear in the body of the gpu.launch operation. These regions are not expected to return any value, so the terminator takes no operands.
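
A minimal sketch of where the terminator appears inside a launch region (all launch bounds are assumed to be a constant %c1):

gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
           threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
  // Kernel body goes here.
  gpu.terminator
}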

Traits: AlwaysSpeculatableImplTrait, HasParent<LaunchOp>, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

gpu.thread_id (gpu::ThreadIdOp) 

Syntax:

operation ::= `gpu.thread_id` $dimension attr-dict

Returns the thread id, i.e., the index of the current thread within the block along the x, y, or z dimension.

Example:

%tIdX = gpu.thread_id x
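
A common pattern combines this op with gpu.block_id and gpu.block_dim to compute a thread's global index; a sketch along the x dimension:

%bIdX = gpu.block_id x
%bDimX = gpu.block_dim x
%tIdX = gpu.thread_id x
%off = arith.muli %bIdX, %bDimX : index
%gIdX = arith.addi %off, %tIdX : index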

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

  • dimension (::mlir::gpu::DimensionAttr): a dimension, either 'x', 'y', or 'z'. Enum cases: x, y, z.

Results:

  • «unnamed»: index

gpu.wait (gpu::WaitOp) 

Wait for async gpu ops to complete.

Syntax:

operation ::= `gpu.wait` custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict

This op synchronizes the host or the device with a list of dependent ops.

If the op contains the async keyword, it returns a new async token which is synchronized with the op arguments. This new token is merely a shortcut to the argument list, and one could replace the uses of the result with the arguments for the same effect. The async version of this op is primarily used to make each async token have a single use during lowering and thereby make forks in async execution explicit. Example usage:

%t0 = gpu.foo async : !gpu.async.token
%t1 = gpu.bar async : !gpu.async.token
%t2 = gpu.wait async [%t0, %t1]
// gpu.baz doesn't run until gpu.foo and gpu.bar have both completed, just
// as if the async dependencies were [%t0, %t1].
%t3 = gpu.baz async [%t2]

If the op does not contain the async keyword, it does not return a new async token but blocks until all ops producing the async dependency tokens have finished execution. All dependent memory operations are visible to the host once this op completes. Example usage:

%t0 = gpu.foo async : !gpu.async.token
%t1 = gpu.bar async : !gpu.async.token
// The gpu.wait op blocks until gpu.foo and gpu.bar have completed.
gpu.wait [%t0, %t1]

Interfaces: GPU_AsyncOpInterface

Operands:

  • asyncDependencies: variadic of async token type

Results:

  • asyncToken: async token type

gpu.yield (gpu::YieldOp) 

GPU yield operation

Syntax:

operation ::= `gpu.yield` attr-dict ($values^ `:` type($values))?

gpu.yield is a special terminator operation for blocks inside regions in gpu ops. It returns values to the immediately enclosing gpu op.

Example:

gpu.yield %f0, %f1 : f32, f32
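
In context, gpu.yield typically terminates a reduction region, for example inside gpu.all_reduce (operand names assumed):

%sum = gpu.all_reduce %a {
^bb(%lhs : f32, %rhs : f32):
  %add = arith.addf %lhs, %rhs : f32
  gpu.yield %add : f32
} : (f32) -> (f32)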

Traits: AlwaysSpeculatableImplTrait, ReturnLike, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

  • values: variadic of any type